Source code for towhee.functional.mixins.display

# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy
from typing import Tuple

from towhee._types import Image
from towhee.types import AudioFrame
from towhee.hparam import param_scope
from towhee.functional.entity import Entity, EntityView
# pylint: disable=dangerous-default-value


[docs]def get_df_on_columns(self, index: Tuple[str]): # pragma: no cover # pylint: disable=import-outside-toplevel from towhee.utils.thirdparty.pandas_utils import pandas as pd def inner(entity: Entity): data = {} for feature in index: data[feature] = getattr(entity, feature) return data data = map(inner, self) df = pd.DataFrame(data) return df
[docs]def calc_df(df, feature: str, target: str): # pragma: no cover # pylint: disable=import-outside-toplevel from towhee.utils.thirdparty.pandas_utils import pandas as pd lst = [] df[feature] = df[feature].fillna('NULL') if target: for i in range(df[feature].nunique()): val = list(df[feature].unique())[i] lst.append([feature, # Variable val, # Value df[df[feature] == val].count()[feature], # All df[(df[feature] == val) & (df[target] == 0)].count()[feature], # Good (think: Fraud == 0) df[(df[feature] == val) & (df[target] == 1)].count()[feature]]) # Bad (think: Fraud == 1) data = pd.DataFrame(lst, columns=['Variable', 'Value', 'All', 'Good', 'Bad']) data['Share'] = data['All'] / data['All'].sum() data['Bad Rate'] = data['Bad'] / data['All'] data['Distribution Good'] = (data['All'] - data['Bad']) / (data['All'].sum() - data['Bad'].sum()) data['Distribution Bad'] = data['Bad'] / data['Bad'].sum() data['WoE'] = numpy.log(data['Distribution Good'] / data['Distribution Bad']) data = data.replace({'WoE': {numpy.inf: 0, -numpy.inf: 0}}) data['IV'] = data['WoE'] * (data['Distribution Good'] - data['Distribution Bad']) data = data.sort_values(by=['Variable', 'Value'], ascending=[True, True]) data.index = range(len(data.index)) iv = data['IV'].sum() print(f'Variable: {feature}\'s IV sum is: {iv}') return data else: for i in range(df[feature].nunique()): val = list(df[feature].unique())[i] lst.append([feature, # Variable val, # Value df[df[feature] == val].count()[feature]]) # All data = pd.DataFrame(lst, columns=['Variable', 'Value', 'All']) data['Share'] = data['All'] / data['All'].sum() return data
def _feature_summarize_callback(self): # pragma: no cover # pylint: disable=import-outside-toplevel from towhee.utils.thirdparty.pandas_utils import pandas as pd def wrapper(_: str, index, *arg, **kws): if isinstance(index, str): index = (index,) index = list(index) if arg: kws['target'], = arg target = None if 'target' in kws: target = kws.pop('target') index.append(target) df = get_df_on_columns(self, index) summarize = pd.DataFrame() if target: index.remove(target) for feature in index: data = calc_df(df, feature, target) summarize = summarize.append(data) # pylint: disable=import-outside-toplevel from towhee.utils.thirdparty import ipython_utils ipython_utils.display(summarize) return wrapper def _plot_callback(self): # pragma: no cover # pylint: disable=unused-argument def wrapper(_: str, index, *arg, **kws): if isinstance(index, str): index = (index,) df = get_df_on_columns(self, index) if 'kind' not in kws: kws.update(kind='hist') df.plot(**kws) return wrapper
[docs]class DisplayMixin: # pragma: no cover """ Mixin for displaying data. """
[docs] def __init__(self): super().__init__() self.feature_summarize = param_scope().dispatch(_feature_summarize_callback(self)) self.plot = param_scope().dispatch(_plot_callback(self))
def as_str(self): return self._factory(map(str, self._iterable))
[docs] def show(self, limit=5, header=None, tablefmt='html', formatter={}): """Print the first n lines of a DataCollection. Args: limit (int, optional): The number of lines to print. Prints all if limit is negative. Defaults to 5. header (_type_, optional): The field names. Defaults to None. tablefmt (str, optional): The format of the output, supports html, plain, grid.. Defaults to 'html'. """ # pylint: disable=protected-access contents = [x for i, x in enumerate(self) if i < limit] if all(isinstance(x, Entity) for x in contents): header = tuple(contents[0].__dict__) if not header else header data = [list(x.__dict__.values()) for x in contents] elif all(isinstance(x, EntityView) for x in contents): header = tuple(contents[0]._table.column_names) if not header else header data = [[getattr(x, n) for n in header] for x in contents] else: data = [[x] for x in contents] table_display(to_printable_table(data, header, tablefmt, formatter), tablefmt)
[docs]def table_display(table, tablefmt='html'): # pragma: no cover """_summary_ Args: table (str): Table in printable format, such as HTML. tablefmt (str, optional): The format of the output, supports html, plain, grid. Defaults to 'html'. Raises: ValueError: Unsupported table format. """ # pylint: disable=import-outside-toplevel from towhee.utils.thirdparty.ipython_utils import display, HTML if tablefmt == 'html': display(HTML(table)) elif tablefmt in ('plain', 'grid'): print(table) else: raise ValueError('unsupported table format %s' % tablefmt)
[docs]def to_printable_table(data, header=None, tablefmt='html', formatter={}): # pragma: no cover """Convert two dimensional data structure into printable table Args: data (List[List], or List[Dict]): The data filled into table. If a list of dict is given, keys are used as column names. header (Iterable[str]), optional): The names of columns defined by user. Defaults to None. tablefmt (str, optional): The format of the output, supports html, plain, grid.. Defaults to 'html'. Raises: ValueError: Unsupported table format Returns: str: The table. """ header = [] if not header else header if tablefmt == 'html': return to_html_table(data, header, formatter) elif tablefmt in ('plain', 'grid'): return to_plain_table(data, header, tablefmt) raise ValueError('unsupported table format %s' % tablefmt)
[docs]def to_plain_table(data, header, tablefmt): # pragma: no cover """Convert two dimensional data structure into plain table Args: data (List[List] or List[Dict]): The data filled into table. If a list of dict is given, keys are used as column names. header (Iterable[str]): The names of columns defined by user. tablefmt (str): The format of the output, supports plain, grid. Returns: str: The table in plain format. """ # pylint: disable=import-outside-toplevel from tabulate import tabulate tb_contents = [[_to_plain_cell(x) for x in r] for r in data] return tabulate(tb_contents, headers=header, tablefmt=tablefmt)
def _to_plain_cell(data): # pragma: no cover if isinstance(data, str): return _text_brief(data) if isinstance(data, Image): return _image_brief(data) elif isinstance(data, AudioFrame): return _audio_frame_brief(data) elif isinstance(data, numpy.ndarray): return _ndarray_brief(data) elif isinstance(data, (list, tuple)): if all(isinstance(x, str) for x in data): return _list_brief(data, _text_brief) elif all(isinstance(x, Image) for x in data): return _list_brief(data, _image_brief) elif all(isinstance(x, AudioFrame) for x in data): return _list_brief(data, _audio_frame_brief) elif all(isinstance(x, numpy.ndarray) for x in data): return _list_brief(data, _ndarray_brief) return _default_brief(data)
[docs]def to_html_table(data, header, formatter={}): # pragma: no cover """Convert two dimensional data structure into html table Args: data (List[List] or List[Dict]): The data filled into table. If a list of dict is given, keys are used as column names. header (Iterable[str]): The names of columns defined by user. Returns: str: The table in html format. """ tb_style = 'style="border-collapse: collapse;"' th_style = 'style="text-align: center; font-size: 130%; border: none;"' str_2_callback = { 'text': _text_brief, 'image': _image_to_html_cell, 'audio_frame': _audio_frame_to_html_cell, 'video_path': _video_path_to_html_cell, } trs = [] trs.append('<tr>' + ' '.join(['<th ' + th_style + '>' + x + '</th>' for x in header]) + '</tr>') to_html_callback = {} for i, field in enumerate(header): cb = formatter.get(field, None) if isinstance(cb, str): cb = str_2_callback.get(cb, None) to_html_callback[i] = cb for r in data: trs.append( '<tr>' + ' '.join([_to_html_td(x, to_html_callback.get(i, None)) for i, x in enumerate(r)]) + '</tr>' ) return '<table ' + tb_style + '>' + ' '.join(trs) + '</table>'
def _to_html_td(data, callback=None): # pragma: no cover def wrap_td_tag(content, align='center', vertical_align='center'): td_style = 'style="' \ + 'text-align: ' + align + '; ' \ + 'vertical-align: ' + vertical_align + '; ' \ + 'border-right: solid 1px #D3D3D3; ' \ + 'border-left: solid 1px #D3D3D3; ' \ + '"' return '<td ' + td_style + '>' + content + '</td>' if callback is not None: return wrap_td_tag(callback(data)) if isinstance(data, str): return wrap_td_tag(_text_brief(data)) if isinstance(data, Image): return wrap_td_tag(_image_to_html_cell(data)) elif isinstance(data, AudioFrame): return wrap_td_tag(_audio_frame_to_html_cell(data)) elif isinstance(data, numpy.ndarray): return wrap_td_tag(_ndarray_brief(data), align='left') elif isinstance(data, (list, tuple)): if all(isinstance(x, str) for x in data): return wrap_td_tag(_text_list_brief(data)) elif all(isinstance(x, Image) for x in data): return wrap_td_tag(_images_to_html_cell(data), vertical_align='top') elif all(isinstance(x, AudioFrame) for x in data): return wrap_td_tag(_audio_frames_to_html_cell(data), vertical_align='top') elif all(isinstance(x, numpy.ndarray) for x in data): return wrap_td_tag(_list_brief(data, _ndarray_brief), align='left') else: return wrap_td_tag(_list_brief(data, _default_brief), align='left') return wrap_td_tag(_default_brief(data)) def _image_to_html_cell(img, width=128, height=128): # pragma: no cover # pylint: disable=import-outside-toplevel from towhee.utils.cv2_utils import cv2 import base64 _, img_encode = cv2.imencode('.JPEG', img) src = 'src="data:image/jpeg;base64,' + base64.b64encode(img_encode).decode() + '" ' w = 'width = "' + str(width) + 'px" ' h = 'height = "' + str(height) + 'px" ' style = 'style = "float:left; padding:2px"' return '<img ' + src + w + h + style + '>' def _images_to_html_cell(imgs, width=128, height=128): # pragma: no cover return ' '.join([_image_to_html_cell(x, width, height) for x in imgs]) def _audio_frame_to_html_cell(frame, width=128, height=128): # pragma: no cover # pylint: disable=import-outside-toplevel from towhee.utils.matplotlib_utils import plt signal = frame[0, ...] fourier = numpy.fft.fft(signal) freq = numpy.fft.fftfreq(signal.shape[-1], 1) fig = plt.figure() plt.plot(freq, fourier.real) fig.canvas.draw() data = numpy.frombuffer(fig.canvas.tostring_rgb(), dtype=numpy.uint8) data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) img = Image(data, 'RGB') plt.close() return _image_to_html_cell(img, width, height) def _audio_frames_to_html_cell(frames, width=128, height=128): # pragma: no cover return ' '.join([_audio_frame_to_html_cell(x, width, height) for x in frames]) def _video_path_to_html_cell(path, width=128, height=128): src = 'src="' + path + '" ' type_suffix_idx = path.rfind('.') if type_suffix_idx == -1: raise ValueError('unsupported video format %s' % path) video_type = 'type="video/' + path[path.rfind('.') + 1:].lower() + '" ' w = 'width = "' + str(width) + 'px" ' h = 'height = "' + str(height) + 'px" ' style = 'style = "float:left; padding:2px"' return '<video ' + w + h + style + 'controls><source ' + src + video_type + '></source></video>' def _ndarray_brief(array, maxlen=3): # pragma: no cover head_vals = [repr(v) for i, v in enumerate(array.flatten()) if i < maxlen] if len(array.flatten()) > maxlen: head_vals.append('...') shape = 'shape=' + repr(array.shape) return '[' + ', '.join(head_vals) + '] ' + shape def _image_brief(img): # pragma: no cover return str(img) def _audio_frame_brief(frame): # pragma: no cover return str(frame) def _text_brief(text, maxlen=128): # pragma: no cover if len(text) > maxlen: return text[:maxlen] + '...' else: return text def _text_list_brief(data, maxlen=16): # pragma: no cover head_vals = ['<br>' + _text_brief(x) + '</br>' for i, x in enumerate(data) if i < maxlen] if len(data) > maxlen: head_vals.append('<br>...</br>') return ' '.join(head_vals) def _list_brief(data, str_method, maxlen=4): # pragma: no cover head_vals = [str_method(x) for i, x in enumerate(data) if i < maxlen] if len(data) > maxlen: head_vals.append('...') return '[' + ','.join(head_vals) + ']' + ' len=' + str(len(data)) def _default_brief(data, maxlen=128): # pragma: no cover s = str(data) return s[:maxlen] + '...' if len(s) > maxlen else s