# Source code for towhee.functional.mixins.serve

# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import queue
import threading
import concurrent.futures
from towhee.functional.entity import Entity
from towhee.functional.option import Some
# pylint: disable=import-outside-toplevel


class _APIWrapper:
    """API Wrapper

    Adapts push-style input to a pull-style DataCollection. Values handed to
    :meth:`feed` are put on an internal queue, and iterating the wrapper drains
    that queue. Entering the wrapper (``with ...``) builds the consuming
    DataCollection as ``cls(self).stream()``.
    """
    # Thread-local slot: ``__enter__`` stores the active wrapper in
    # ``tls.place_holder``; ``__exit__`` resets it to ``None``.
    tls = threading.local()

    def __init__(self, index=None, cls=None) -> None:
        """
        Args:
            index (str or list, optional): field name(s) used by :meth:`feed` to
                build an ``Entity``; a non-list value is wrapped in a one-element
                list. ``None`` means values are queued as-is.
            cls (type, optional): class instantiated with this wrapper on enter;
                it must provide ``stream()``.
        """
        self._queue = queue.Queue()
        self._cls = cls
        if index is None:
            self._index = None
        elif isinstance(index, list):
            self._index = index
        else:
            self._index = [index]

    def feed(self, x) -> None:
        """Queue one input, wrapped in ``Some``.

        Without an index the value is queued unchanged. With an index, the field
        names are zipped with the values of ``x`` into an ``Entity``; when there
        is a single field name, ``x`` itself is that field's value.
        """
        if self._index is not None:
            values = (x, ) if len(self._index) == 1 else x
            x = Entity(**dict(zip(self._index, values)))
        self._queue.put(Some(x))

    def __iter__(self):
        """Yield queued items forever, blocking while the queue is empty."""
        while True:
            item = self._queue.get()
            yield item

    def __enter__(self):
        """Register this wrapper for the current thread and return ``cls(self).stream()``."""
        _APIWrapper.tls.place_holder = self
        collection = self._cls(self)
        return collection.stream()

    def __exit__(self, exc_type, exc_value, traceback):
        """Clear this thread's registration (if any); exceptions are not suppressed."""
        slot = _APIWrapper.tls
        if hasattr(slot, 'place_holder'):
            slot.place_holder = None


class _PipeWrapper:
    """
    Allows for the execution of a DataCollection chain as a function. Thread safe.

    A single daemon worker thread pulls one output from ``pipe`` for every
    request queued by :meth:`execute`. Futures and inputs are enqueued under the
    same lock, so outputs are matched to callers in FIFO order.
    """

    def __init__(self, pipe, place_holder) -> None:
        self._pipe = pipe
        self._place_holder = place_holder
        self._futures = queue.Queue()
        self._lock = threading.Lock()
        self._executor = threading.Thread(target=self.worker, daemon=True)
        self._executor.start()

    def worker(self):
        """Resolve queued futures, one pipeline output per future, forever.

        Errors raised while pulling from the pipeline are forwarded to the
        waiting caller. If one escaped, this thread would die and the caller,
        and every later caller, would block forever in ``future.result()``.
        """
        while True:
            future = self._futures.get()
            try:
                result = next(self._pipe)
            except StopIteration:
                # The pipeline is exhausted (e.g. it already failed once).
                # StopIteration cannot be set on an asyncio Future or raised
                # from a coroutine, so report it as a RuntimeError.
                future.set_exception(RuntimeError('pipeline is exhausted'))
            except Exception as e:  # pylint: disable=broad-except
                future.set_exception(e)
            else:
                future.set_result(result)

    def execute(self, x):
        """Feed ``x`` into the pipeline and block until its output is available.

        Raises:
            Exception: whatever the pipeline raised while producing this output;
                ``RuntimeError`` if the pipeline has been exhausted.
        """
        # Enqueue the future and feed the input atomically so the order of
        # futures always matches the order of inputs.
        with self._lock:
            future = concurrent.futures.Future()
            self._futures.put(future)
            self._place_holder.feed(x)
        return future.result()


async def _decode_content(req):
    """Decode an incoming request body according to its ``Content-Type`` header.

    Returns:
        The parsed form for ``multipart/form-data``, the raw body bytes for
        ``image/*`` types, and otherwise the body decoded as UTF-8 text.
    """
    from multipart.multipart import parse_options_header

    mime, _params = parse_options_header(req.headers.get('Content-Type'))

    if mime == b'multipart/form-data':
        return await req.form()

    # Classify before reading the body, matching the original evaluation order.
    is_image = mime.startswith(b'image/')
    body = await req.body()
    return body if is_image else body.decode()


class ServeMixin:
    """
    Mixin for API serve
    """

    @staticmethod
    def _current_api():
        """Return the ``_APIWrapper`` registered by the enclosing ``with ... api()`` block.

        ``_APIWrapper.__enter__`` stores itself in the thread-local
        ``_APIWrapper.tls.place_holder``, and ``__exit__`` resets it to ``None``.

        Raises:
            RuntimeError: if no API context is active in the current thread.
        """
        api = getattr(_APIWrapper.tls, 'place_holder', None)
        if api is None:
            raise RuntimeError(
                'serve()/as_function() must be called inside a '
                '`with towhee.api() as api:` block')
        return api

    def serve(self, path='/', app=None):
        """
        Serve the DataFrame as a RESTful API

        Each POST to ``path`` is decoded by ``_decode_content``, fed through the
        pipeline, and the unwrapped pipeline output is returned as the response.

        Args:
            path (str, optional): API path. Defaults to '/'.
            app (_type_, optional): The FastAPI app the API bind to, will create one if None.

        Returns:
            _type_: the app that bind to

        Raises:
            RuntimeError: if called outside a ``with towhee.api() as api:`` block.

        Examples:

        >>> from fastapi import FastAPI
        >>> from fastapi.testclient import TestClient
        >>> app = FastAPI()

        >>> import towhee
        >>> with towhee.api() as api:
        ...     app1 = (
        ...         api.map(lambda x: x+' -> 1')
        ...            .map(lambda x: x+' => 1')
        ...            .serve('/app1', app)
        ...     )

        >>> with towhee.api['x']() as api:
        ...     app2 = (
        ...         api.runas_op['x', 'x_plus_1'](func=lambda x: x+' -> 2')
        ...            .runas_op['x_plus_1', 'y'](func=lambda x: x+' => 2')
        ...            .select['y']()
        ...            .serve('/app2', app)
        ...     )

        >>> with towhee.api() as api:
        ...     app3 = (
        ...         api.parse_json()
        ...            .runas_op['x', 'x_plus_1'](func=lambda x: x+' -> 3')
        ...            .runas_op['x_plus_1', 'y'](func=lambda x: x+' => 3')
        ...            .select['y']()
        ...            .serve('/app3', app)
        ...     )

        >>> client = TestClient(app)
        >>> client.post('/app1', '1').text
        '"1 -> 1 => 1"'
        >>> client.post('/app2', '2').text
        '{"y":"2 -> 2 => 2"}'
        >>> client.post('/app3', '{"x": "3"}').text
        '{"y":"3 -> 3 => 3"}'
        """
        import asyncio

        # Fail fast: without an active api() context the pipeline has no input.
        api = self._current_api()

        if app is None:
            from fastapi import FastAPI, Request
            app = FastAPI()
        else:
            from fastapi import Request

        pipeline = _PipeWrapper(self._iterable, api)

        @app.post(path)
        async def wrapper(req: Request):
            payload = await _decode_content(req)
            # `execute` blocks until the pipeline worker thread produces the
            # output. Run it in the default executor so the event loop is not
            # stalled while waiting.
            loop = asyncio.get_running_loop()
            rsp = await loop.run_in_executor(None, pipeline.execute, payload)
            return rsp.get()

        return app

    def as_function(self):
        """
        Make the DataFrame as callable function

        The returned callable feeds its argument through the pipeline and returns
        the unwrapped output. Its ``dag_info`` attribute holds ``compile_dag()``.

        Returns:
            _type_: a callable function

        Raises:
            RuntimeError: if called outside a ``with towhee.api() as api:`` block.

        Examples:

        >>> import towhee
        >>> with towhee.api() as api:
        ...     func1 = (
        ...         api.map(lambda x: x+' -> 1')
        ...            .map(lambda x: x+' => 1')
        ...            .as_function()
        ...     )

        >>> with towhee.api['x']() as api:
        ...     func2 = (
        ...         api.runas_op['x', 'x_plus_1'](func=lambda x: x+' -> 2')
        ...            .runas_op['x_plus_1', 'y'](func=lambda x: x+' => 2')
        ...            .select['y']()
        ...            .as_raw()
        ...            .as_function()
        ...     )

        >>> with towhee.api() as api:
        ...     func3 = (
        ...         api.parse_json()
        ...            .runas_op['x', 'x_plus_1'](func=lambda x: x+' -> 3')
        ...            .runas_op['x_plus_1', 'y'](func=lambda x: x+' => 3')
        ...            .select['y']()
        ...            .as_json()
        ...            .as_function()
        ...     )

        >>> func1('1')
        '1 -> 1 => 1'
        >>> func2('2')
        '2 -> 2 => 2'
        >>> func3('{"x": "3"}')
        '{"y": "3 -> 3 => 3"}'
        """
        api = self._current_api()
        as_function_self = self
        pipeline = _PipeWrapper(self._iterable, api)

        class _Wrapper:
            """Callable that pushes one input through the pipeline and returns its output."""

            def __init__(self):
                self.dag_info = as_function_self.compile_dag()
                self.__name__ = self.__class__.__name__

            def __call__(self, req):
                return pipeline.execute(req).get()

        return _Wrapper()

    @classmethod
    def api(cls, index=None):
        """
        Create an API context manager bound to ``cls``.

        Args:
            index (str or list, optional): field name(s) used to wrap each fed
                value into an ``Entity``; ``None`` feeds values as-is.

        Returns:
            _APIWrapper: entering it registers the wrapper for the current thread
            and yields ``cls(wrapper).stream()``.
        """
        return _APIWrapper(index=index, cls=cls)