author    S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit    4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree      ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/huggingface_hub/inference_api.py
parent    cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download  gn-ai-master.tar.gz
two versions of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/huggingface_hub/inference_api.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/huggingface_hub/inference_api.py  217
1 file changed, 217 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/huggingface_hub/inference_api.py b/.venv/lib/python3.12/site-packages/huggingface_hub/inference_api.py
new file mode 100644
index 00000000..f895fcc6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/huggingface_hub/inference_api.py
@@ -0,0 +1,217 @@
+import io
+from typing import Any, Dict, List, Optional, Union
+
+from . import constants
+from .hf_api import HfApi
+from .utils import build_hf_headers, get_session, is_pillow_available, logging, validate_hf_hub_args
+from .utils._deprecation import _deprecate_method
+
+
+logger = logging.get_logger(__name__)
+
+
+ALL_TASKS = [
+ # NLP
+ "text-classification",
+ "token-classification",
+ "table-question-answering",
+ "question-answering",
+ "zero-shot-classification",
+ "translation",
+ "summarization",
+ "conversational",
+ "feature-extraction",
+ "text-generation",
+ "text2text-generation",
+ "fill-mask",
+ "sentence-similarity",
+ # Audio
+ "text-to-speech",
+ "automatic-speech-recognition",
+ "audio-to-audio",
+ "audio-classification",
+ "voice-activity-detection",
+ # Computer vision
+ "image-classification",
+ "object-detection",
+ "image-segmentation",
+ "text-to-image",
+ "image-to-image",
+ # Others
+ "tabular-classification",
+ "tabular-regression",
+]
+
+
+class InferenceApi:
+ """Client to configure requests and make calls to the HuggingFace Inference API.
+
+ Example:
+
+ ```python
+ >>> from huggingface_hub.inference_api import InferenceApi
+
+ >>> # Mask-fill example
+ >>> inference = InferenceApi("bert-base-uncased")
+ >>> inference(inputs="The goal of life is [MASK].")
+ [{'sequence': 'the goal of life is life.', 'score': 0.10933292657136917, 'token': 2166, 'token_str': 'life'}]
+
+ >>> # Question Answering example
+ >>> inference = InferenceApi("deepset/roberta-base-squad2")
+ >>> inputs = {
+ ... "question": "What's my name?",
+ ... "context": "My name is Clara and I live in Berkeley.",
+ ... }
+ >>> inference(inputs)
+ {'score': 0.9326569437980652, 'start': 11, 'end': 16, 'answer': 'Clara'}
+
+ >>> # Zero-shot example
+ >>> inference = InferenceApi("typeform/distilbert-base-uncased-mnli")
+ >>> inputs = "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!"
+ >>> params = {"candidate_labels": ["refund", "legal", "faq"]}
+ >>> inference(inputs, params)
+ {'sequence': 'Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!', 'labels': ['refund', 'faq', 'legal'], 'scores': [0.9378499388694763, 0.04914155602455139, 0.013008488342165947]}
+
+ >>> # Overriding configured task
+ >>> inference = InferenceApi("bert-base-uncased", task="feature-extraction")
+
+ >>> # Text-to-image
+ >>> inference = InferenceApi("stabilityai/stable-diffusion-2-1")
+ >>> inference("cat")
+ <PIL.PngImagePlugin.PngImageFile image (...)>
+
+ >>> # Return as raw response to parse the output yourself
+ >>> inference = InferenceApi("mio/amadeus")
+ >>> response = inference("hello world", raw_response=True)
+ >>> response.headers
+ {"Content-Type": "audio/flac", ...}
+ >>> response.content # raw bytes from server
+ b'(...)'
+ ```
+ """
+
+ @validate_hf_hub_args
+ @_deprecate_method(
+ version="1.0",
+ message=(
+ "`InferenceApi` client is deprecated in favor of the more feature-complete `InferenceClient`. Check out"
+ " this guide to learn how to convert your script to use it:"
+ " https://huggingface.co/docs/huggingface_hub/guides/inference#legacy-inferenceapi-client."
+ ),
+ )
+ def __init__(
+ self,
+ repo_id: str,
+ task: Optional[str] = None,
+ token: Optional[str] = None,
+ gpu: bool = False,
+ ):
+        """Initializes headers and API call information.
+
+ Args:
+            repo_id (``str``):
+                ID of the repository (e.g. `user/bert-base-uncased`).
+            task (``str``, `optional`, defaults to ``None``):
+                Task to force instead of the one specified in the repository.
+            token (`str`, `optional`):
+                The API token to use as HTTP bearer authorization. You can find
+                your token at https://huggingface.co/settings/token.
+                Alternatively, you can find both your organization and personal
+                API tokens using `HfApi().whoami(token)`.
+            gpu (`bool`, `optional`, defaults to `False`):
+                Whether to use GPU instead of CPU for inference (requires at
+                least the Startup plan).
+ """
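+        # "wait_for_model" asks the Inference API to wait for the model to be
+        # loaded rather than returning a 503 error while it is still loading.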
+ self.options = {"wait_for_model": True, "use_gpu": gpu}
+ self.headers = build_hf_headers(token=token)
+
+ # Configure task
+ model_info = HfApi(token=token).model_info(repo_id=repo_id)
+ if not model_info.pipeline_tag and not task:
+ raise ValueError(
+ "Task not specified in the repository. Please add it to the model card"
+ " using pipeline_tag"
+ " (https://huggingface.co/docs#how-is-a-models-type-of-inference-api-and-widget-determined)"
+ )
+
+ if task and task != model_info.pipeline_tag:
+ if task not in ALL_TASKS:
+                raise ValueError(f"Invalid task {task}. Valid tasks are listed in `ALL_TASKS`.")
+
+ logger.warning(
+ "You're using a different task than the one specified in the"
+ " repository. Be sure to know what you're doing :)"
+ )
+ self.task = task
+ else:
+ assert model_info.pipeline_tag is not None, "Pipeline tag cannot be None"
+ self.task = model_info.pipeline_tag
+
+ self.api_url = f"{constants.INFERENCE_ENDPOINT}/pipeline/{self.task}/{repo_id}"
+
+ def __repr__(self):
+ # Do not add headers to repr to avoid leaking token.
+ return f"InferenceAPI(api_url='{self.api_url}', task='{self.task}', options={self.options})"
+
+ def __call__(
+ self,
+ inputs: Optional[Union[str, Dict, List[str], List[List[str]]]] = None,
+ params: Optional[Dict] = None,
+ data: Optional[bytes] = None,
+ raw_response: bool = False,
+ ) -> Any:
+ """Make a call to the Inference API.
+
+ Args:
+ inputs (`str` or `Dict` or `List[str]` or `List[List[str]]`, *optional*):
+ Inputs for the prediction.
+ params (`Dict`, *optional*):
+ Additional parameters for the models. Will be sent as `parameters` in the
+ payload.
+ data (`bytes`, *optional*):
+ Bytes content of the request. In this case, leave `inputs` and `params` empty.
+ raw_response (`bool`, defaults to `False`):
+ If `True`, the raw `Response` object is returned. You can parse its content
+ as preferred. By default, the content is parsed into a more practical format
+ (json dictionary or PIL Image for example).
+ """
+ # Build payload
+ payload: Dict[str, Any] = {
+ "options": self.options,
+ }
+ if inputs:
+ payload["inputs"] = inputs
+ if params:
+ payload["parameters"] = params
+
+ # Make API call
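+        # Note: `requests` only uses the `json` argument when `data` is unset,
+        # so a raw-bytes `data` payload takes precedence over the dict above.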
+ response = get_session().post(self.api_url, headers=self.headers, json=payload, data=data)
+
+ # Let the user handle the response
+ if raw_response:
+ return response
+
+ # By default, parse the response for the user.
+ content_type = response.headers.get("Content-Type") or ""
+ if content_type.startswith("image"):
+ if not is_pillow_available():
+ raise ImportError(
+ f"Task '{self.task}' returned as image but Pillow is not installed."
+ " Please install it (`pip install Pillow`) or pass"
+ " `raw_response=True` to get the raw `Response` object and parse"
+ " the image by yourself."
+ )
+
+ from PIL import Image
+
+ return Image.open(io.BytesIO(response.content))
+ elif content_type == "application/json":
+ return response.json()
+ else:
+ raise NotImplementedError(
+ f"{content_type} output type is not implemented yet. You can pass"
+ " `raw_response=True` to get the raw `Response` object and parse the"
+ " output by yourself."
+ )
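
The deprecation notice above points to `InferenceClient` as the replacement for this legacy client. As a minimal migration sketch (assuming a recent `huggingface_hub` release that ships `InferenceClient`; the model IDs are taken from the docstring examples above), the legacy calls translate roughly as:

```python
from huggingface_hub import InferenceClient

client = InferenceClient()

# Legacy: InferenceApi("bert-base-uncased")(inputs="The goal of life is [MASK].")
print(client.fill_mask("The goal of life is [MASK].", model="bert-base-uncased"))

# Legacy: InferenceApi("deepset/roberta-base-squad2")(inputs={"question": ..., "context": ...})
print(
    client.question_answering(
        question="What's my name?",
        context="My name is Clara and I live in Berkeley.",
        model="deepset/roberta-base-squad2",
    )
)
```

Note the design difference: `InferenceApi` bakes the task into the pipeline URL at construction time, while `InferenceClient` selects the task per method call, so a single client can serve several models and tasks.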