about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/huggingface_hub/inference/_common.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/huggingface_hub/inference/_common.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/huggingface_hub/inference/_common.py')
-rw-r--r--.venv/lib/python3.12/site-packages/huggingface_hub/inference/_common.py422
1 files changed, 422 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/huggingface_hub/inference/_common.py b/.venv/lib/python3.12/site-packages/huggingface_hub/inference/_common.py
new file mode 100644
index 00000000..574f726b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/huggingface_hub/inference/_common.py
@@ -0,0 +1,422 @@
+# coding=utf-8
+# Copyright 2023-present, the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Contains utilities used by both the sync and async inference clients."""
+
+import base64
+import io
+import json
+import logging
+from contextlib import contextmanager
+from dataclasses import dataclass
+from pathlib import Path
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncIterable,
+    BinaryIO,
+    ContextManager,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    Literal,
+    NoReturn,
+    Optional,
+    Union,
+    overload,
+)
+
+from requests import HTTPError
+
+from huggingface_hub.errors import (
+    GenerationError,
+    IncompleteGenerationError,
+    OverloadedError,
+    TextGenerationError,
+    UnknownError,
+    ValidationError,
+)
+
+from ..utils import get_session, is_aiohttp_available, is_numpy_available, is_pillow_available
+from ._generated.types import ChatCompletionStreamOutput, TextGenerationStreamOutput
+
+
+if TYPE_CHECKING:
+    from aiohttp import ClientResponse, ClientSession
+    from PIL.Image import Image
+
+# TYPES
+UrlT = str
+PathT = Union[str, Path]
+BinaryT = Union[bytes, BinaryIO]
+ContentT = Union[BinaryT, PathT, UrlT]
+
+# Use to set a Accept: image/png header
+TASKS_EXPECTING_IMAGES = {"text-to-image", "image-to-image"}
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class RequestParameters:
+    url: str
+    task: str
+    model: Optional[str]
+    json: Optional[Union[str, Dict, List]]
+    data: Optional[ContentT]
+    headers: Dict[str, Any]
+
+
+# Add dataclass for ModelStatus. We use this dataclass in get_model_status function.
+@dataclass
+class ModelStatus:
+    """
+    This Dataclass represents the model status in the HF Inference API.
+
+    Args:
+        loaded (`bool`):
+            If the model is currently loaded into HF's Inference API. Models
+            are loaded on-demand, leading to the user's first request taking longer.
+            If a model is loaded, you can be assured that it is in a healthy state.
+        state (`str`):
+            The current state of the model. This can be 'Loaded', 'Loadable', 'TooBig'.
+            If a model's state is 'Loadable', it's not too big and has a supported
+            backend. Loadable models are automatically loaded when the user first
+            requests inference on the endpoint. This means it is transparent for the
+            user to load a model, except that the first call takes longer to complete.
+        compute_type (`Dict`):
+            Information about the compute resource the model is using or will use, such as 'gpu' type and number of
+            replicas.
+        framework (`str`):
+            The name of the framework that the model was built with, such as 'transformers'
+            or 'text-generation-inference'.
+    """
+
+    loaded: bool
+    state: str
+    compute_type: Dict
+    framework: str
+
+
+## IMPORT UTILS
+
+
+def _import_aiohttp():
+    # Make sure `aiohttp` is installed on the machine.
+    if not is_aiohttp_available():
+        raise ImportError("Please install aiohttp to use `AsyncInferenceClient` (`pip install aiohttp`).")
+    import aiohttp
+
+    return aiohttp
+
+
+def _import_numpy():
+    """Make sure `numpy` is installed on the machine."""
+    if not is_numpy_available():
+        raise ImportError("Please install numpy to use deal with embeddings (`pip install numpy`).")
+    import numpy
+
+    return numpy
+
+
+def _import_pil_image():
+    """Make sure `PIL` is installed on the machine."""
+    if not is_pillow_available():
+        raise ImportError(
+            "Please install Pillow to use deal with images (`pip install Pillow`). If you don't want the image to be"
+            " post-processed, use `client.post(...)` and get the raw response from the server."
+        )
+    from PIL import Image
+
+    return Image
+
+
+## ENCODING / DECODING UTILS
+
+
+@overload
+def _open_as_binary(
+    content: ContentT,
+) -> ContextManager[BinaryT]: ...  # means "if input is not None, output is not None"
+
+
+@overload
+def _open_as_binary(
+    content: Literal[None],
+) -> ContextManager[Literal[None]]: ...  # means "if input is None, output is None"
+
+
+@contextmanager  # type: ignore
+def _open_as_binary(content: Optional[ContentT]) -> Generator[Optional[BinaryT], None, None]:
+    """Open `content` as a binary file, either from a URL, a local path, or raw bytes.
+
+    Do nothing if `content` is None,
+
+    TODO: handle a PIL.Image as input
+    TODO: handle base64 as input
+    """
+    # If content is a string => must be either a URL or a path
+    if isinstance(content, str):
+        if content.startswith("https://") or content.startswith("http://"):
+            logger.debug(f"Downloading content from {content}")
+            yield get_session().get(content).content  # TODO: retrieve as stream and pipe to post request ?
+            return
+        content = Path(content)
+        if not content.exists():
+            raise FileNotFoundError(
+                f"File not found at {content}. If `data` is a string, it must either be a URL or a path to a local"
+                " file. To pass raw content, please encode it as bytes first."
+            )
+
+    # If content is a Path => open it
+    if isinstance(content, Path):
+        logger.debug(f"Opening content from {content}")
+        with content.open("rb") as f:
+            yield f
+    else:
+        # Otherwise: already a file-like object or None
+        yield content
+
+
+def _b64_encode(content: ContentT) -> str:
+    """Encode a raw file (image, audio) into base64. Can be bytes, an opened file, a path or a URL."""
+    with _open_as_binary(content) as data:
+        data_as_bytes = data if isinstance(data, bytes) else data.read()
+        return base64.b64encode(data_as_bytes).decode()
+
+
+def _b64_to_image(encoded_image: str) -> "Image":
+    """Parse a base64-encoded string into a PIL Image."""
+    Image = _import_pil_image()
+    return Image.open(io.BytesIO(base64.b64decode(encoded_image)))
+
+
+def _bytes_to_list(content: bytes) -> List:
+    """Parse bytes from a Response object into a Python list.
+
+    Expects the response body to be JSON-encoded data.
+
+    NOTE: This is exactly the same implementation as `_bytes_to_dict` and will not complain if the returned data is a
+    dictionary. The only advantage of having both is to help the user (and mypy) understand what kind of data to expect.
+    """
+    return json.loads(content.decode())
+
+
+def _bytes_to_dict(content: bytes) -> Dict:
+    """Parse bytes from a Response object into a Python dictionary.
+
+    Expects the response body to be JSON-encoded data.
+
+    NOTE: This is exactly the same implementation as `_bytes_to_list` and will not complain if the returned data is a
+    list. The only advantage of having both is to help the user (and mypy) understand what kind of data to expect.
+    """
+    return json.loads(content.decode())
+
+
+def _bytes_to_image(content: bytes) -> "Image":
+    """Parse bytes from a Response object into a PIL Image.
+
+    Expects the response body to be raw bytes. To deal with b64 encoded images, use `_b64_to_image` instead.
+    """
+    Image = _import_pil_image()
+    return Image.open(io.BytesIO(content))
+
+
+def _as_dict(response: Union[bytes, Dict]) -> Dict:
+    return json.loads(response) if isinstance(response, bytes) else response
+
+
+## PAYLOAD UTILS
+
+
+## STREAMING UTILS
+
+
+def _stream_text_generation_response(
+    bytes_output_as_lines: Iterable[bytes], details: bool
+) -> Union[Iterable[str], Iterable[TextGenerationStreamOutput]]:
+    """Used in `InferenceClient.text_generation`."""
+    # Parse ServerSentEvents
+    for byte_payload in bytes_output_as_lines:
+        try:
+            output = _format_text_generation_stream_output(byte_payload, details)
+        except StopIteration:
+            break
+        if output is not None:
+            yield output
+
+
+async def _async_stream_text_generation_response(
+    bytes_output_as_lines: AsyncIterable[bytes], details: bool
+) -> Union[AsyncIterable[str], AsyncIterable[TextGenerationStreamOutput]]:
+    """Used in `AsyncInferenceClient.text_generation`."""
+    # Parse ServerSentEvents
+    async for byte_payload in bytes_output_as_lines:
+        try:
+            output = _format_text_generation_stream_output(byte_payload, details)
+        except StopIteration:
+            break
+        if output is not None:
+            yield output
+
+
+def _format_text_generation_stream_output(
+    byte_payload: bytes, details: bool
+) -> Optional[Union[str, TextGenerationStreamOutput]]:
+    if not byte_payload.startswith(b"data:"):
+        return None  # empty line
+
+    if byte_payload.strip() == b"data: [DONE]":
+        raise StopIteration("[DONE] signal received.")
+
+    # Decode payload
+    payload = byte_payload.decode("utf-8")
+    json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
+
+    # Either an error as being returned
+    if json_payload.get("error") is not None:
+        raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))
+
+    # Or parse token payload
+    output = TextGenerationStreamOutput.parse_obj_as_instance(json_payload)
+    return output.token.text if not details else output
+
+
+def _stream_chat_completion_response(
+    bytes_lines: Iterable[bytes],
+) -> Iterable[ChatCompletionStreamOutput]:
+    """Used in `InferenceClient.chat_completion` if model is served with TGI."""
+    for item in bytes_lines:
+        try:
+            output = _format_chat_completion_stream_output(item)
+        except StopIteration:
+            break
+        if output is not None:
+            yield output
+
+
+async def _async_stream_chat_completion_response(
+    bytes_lines: AsyncIterable[bytes],
+) -> AsyncIterable[ChatCompletionStreamOutput]:
+    """Used in `AsyncInferenceClient.chat_completion`."""
+    async for item in bytes_lines:
+        try:
+            output = _format_chat_completion_stream_output(item)
+        except StopIteration:
+            break
+        if output is not None:
+            yield output
+
+
+def _format_chat_completion_stream_output(
+    byte_payload: bytes,
+) -> Optional[ChatCompletionStreamOutput]:
+    if not byte_payload.startswith(b"data:"):
+        return None  # empty line
+
+    if byte_payload.strip() == b"data: [DONE]":
+        raise StopIteration("[DONE] signal received.")
+
+    # Decode payload
+    payload = byte_payload.decode("utf-8")
+    json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
+
+    # Either an error as being returned
+    if json_payload.get("error") is not None:
+        raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))
+
+    # Or parse token payload
+    return ChatCompletionStreamOutput.parse_obj_as_instance(json_payload)
+
+
+async def _async_yield_from(client: "ClientSession", response: "ClientResponse") -> AsyncIterable[bytes]:
+    async for byte_payload in response.content:
+        yield byte_payload.strip()
+    await client.close()
+
+
+# "TGI servers" are servers running with the `text-generation-inference` backend.
+# This backend is the go-to solution to run large language models at scale. However,
+# for some smaller models (e.g. "gpt2") the default `transformers` + `api-inference`
+# solution is still in use.
+#
+# Both approaches have very similar APIs, but not exactly the same. What we do first in
+# the `text_generation` method is to assume the model is served via TGI. If we realize
+# it's not the case (i.e. we receive an HTTP 400 Bad Request), we fallback to the
+# default API with a warning message. When that's the case, We remember the unsupported
+# attributes for this model in the `_UNSUPPORTED_TEXT_GENERATION_KWARGS` global variable.
+#
+# In addition, TGI servers have a built-in API route for chat-completion, which is not
+# available on the default API. We use this route to provide a more consistent behavior
+# when available.
+#
+# For more details, see https://github.com/huggingface/text-generation-inference and
+# https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task.
+
+_UNSUPPORTED_TEXT_GENERATION_KWARGS: Dict[Optional[str], List[str]] = {}
+
+
+def _set_unsupported_text_generation_kwargs(model: Optional[str], unsupported_kwargs: List[str]) -> None:
+    _UNSUPPORTED_TEXT_GENERATION_KWARGS.setdefault(model, []).extend(unsupported_kwargs)
+
+
+def _get_unsupported_text_generation_kwargs(model: Optional[str]) -> List[str]:
+    return _UNSUPPORTED_TEXT_GENERATION_KWARGS.get(model, [])
+
+
+# TEXT GENERATION ERRORS
+# ----------------------
+# Text-generation errors are parsed separately to handle as much as possible the errors returned by the text generation
+# inference project (https://github.com/huggingface/text-generation-inference).
+# ----------------------
+
+
+def raise_text_generation_error(http_error: HTTPError) -> NoReturn:
+    """
+    Try to parse text-generation-inference error message and raise HTTPError in any case.
+
+    Args:
+        error (`HTTPError`):
+            The HTTPError that have been raised.
+    """
+    # Try to parse a Text Generation Inference error
+
+    try:
+        # Hacky way to retrieve payload in case of aiohttp error
+        payload = getattr(http_error, "response_error_payload", None) or http_error.response.json()
+        error = payload.get("error")
+        error_type = payload.get("error_type")
+    except Exception:  # no payload
+        raise http_error
+
+    # If error_type => more information than `hf_raise_for_status`
+    if error_type is not None:
+        exception = _parse_text_generation_error(error, error_type)
+        raise exception from http_error
+
+    # Otherwise, fallback to default error
+    raise http_error
+
+
+def _parse_text_generation_error(error: Optional[str], error_type: Optional[str]) -> TextGenerationError:
+    if error_type == "generation":
+        return GenerationError(error)  # type: ignore
+    if error_type == "incomplete_generation":
+        return IncompleteGenerationError(error)  # type: ignore
+    if error_type == "overloaded":
+        return OverloadedError(error)  # type: ignore
+    if error_type == "validation":
+        return ValidationError(error)  # type: ignore
+    return UnknownError(error)  # type: ignore