author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
---|---|---
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/huggingface_hub/inference_api.py
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
Diffstat (limited to '.venv/lib/python3.12/site-packages/huggingface_hub/inference_api.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/huggingface_hub/inference_api.py | 217
1 file changed, 217 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/huggingface_hub/inference_api.py b/.venv/lib/python3.12/site-packages/huggingface_hub/inference_api.py
new file mode 100644
index 00000000..f895fcc6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/huggingface_hub/inference_api.py
@@ -0,0 +1,217 @@
+import io
+from typing import Any, Dict, List, Optional, Union
+
+from . import constants
+from .hf_api import HfApi
+from .utils import build_hf_headers, get_session, is_pillow_available, logging, validate_hf_hub_args
+from .utils._deprecation import _deprecate_method
+
+
+logger = logging.get_logger(__name__)
+
+
+ALL_TASKS = [
+    # NLP
+    "text-classification",
+    "token-classification",
+    "table-question-answering",
+    "question-answering",
+    "zero-shot-classification",
+    "translation",
+    "summarization",
+    "conversational",
+    "feature-extraction",
+    "text-generation",
+    "text2text-generation",
+    "fill-mask",
+    "sentence-similarity",
+    # Audio
+    "text-to-speech",
+    "automatic-speech-recognition",
+    "audio-to-audio",
+    "audio-classification",
+    "voice-activity-detection",
+    # Computer vision
+    "image-classification",
+    "object-detection",
+    "image-segmentation",
+    "text-to-image",
+    "image-to-image",
+    # Others
+    "tabular-classification",
+    "tabular-regression",
+]
+
+
+class InferenceApi:
+    """Client to configure requests and make calls to the HuggingFace Inference API.
+
+    Example:
+
+    ```python
+    >>> from huggingface_hub.inference_api import InferenceApi
+
+    >>> # Mask-fill example
+    >>> inference = InferenceApi("bert-base-uncased")
+    >>> inference(inputs="The goal of life is [MASK].")
+    [{'sequence': 'the goal of life is life.', 'score': 0.10933292657136917, 'token': 2166, 'token_str': 'life'}]
+
+    >>> # Question Answering example
+    >>> inference = InferenceApi("deepset/roberta-base-squad2")
+    >>> inputs = {
+    ...     "question": "What's my name?",
+    ...     "context": "My name is Clara and I live in Berkeley.",
+    ... }
+    >>> inference(inputs)
+    {'score': 0.9326569437980652, 'start': 11, 'end': 16, 'answer': 'Clara'}
+
+    >>> # Zero-shot example
+    >>> inference = InferenceApi("typeform/distilbert-base-uncased-mnli")
+    >>> inputs = "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!"
+    >>> params = {"candidate_labels": ["refund", "legal", "faq"]}
+    >>> inference(inputs, params)
+    {'sequence': 'Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!', 'labels': ['refund', 'faq', 'legal'], 'scores': [0.9378499388694763, 0.04914155602455139, 0.013008488342165947]}
+
+    >>> # Overriding configured task
+    >>> inference = InferenceApi("bert-base-uncased", task="feature-extraction")
+
+    >>> # Text-to-image
+    >>> inference = InferenceApi("stabilityai/stable-diffusion-2-1")
+    >>> inference("cat")
+    <PIL.PngImagePlugin.PngImageFile image (...)>
+
+    >>> # Return as raw response to parse the output yourself
+    >>> inference = InferenceApi("mio/amadeus")
+    >>> response = inference("hello world", raw_response=True)
+    >>> response.headers
+    {"Content-Type": "audio/flac", ...}
+    >>> response.content  # raw bytes from server
+    b'(...)'
+    ```
+    """
+
+    @validate_hf_hub_args
+    @_deprecate_method(
+        version="1.0",
+        message=(
+            "`InferenceApi` client is deprecated in favor of the more feature-complete `InferenceClient`. Check out"
+            " this guide to learn how to convert your script to use it:"
+            " https://huggingface.co/docs/huggingface_hub/guides/inference#legacy-inferenceapi-client."
+        ),
+    )
+    def __init__(
+        self,
+        repo_id: str,
+        task: Optional[str] = None,
+        token: Optional[str] = None,
+        gpu: bool = False,
+    ):
+        """Initializes headers and API call information.
+
+        Args:
+            repo_id (`str`):
+                ID of the repository (e.g. `user/bert-base-uncased`).
+            task (`str`, *optional*, defaults to `None`):
+                Whether to force a task instead of using the task specified in the
+                repository.
+            token (`str`, *optional*):
+                The API token to use as HTTP bearer authorization. This is not
+                the authentication token. You can find the token in
+                https://huggingface.co/settings/token. Alternatively, you can
+                find both your organizations and personal API tokens using
+                `HfApi().whoami(token)`.
+            gpu (`bool`, *optional*, defaults to `False`):
+                Whether to use GPU instead of CPU for inference (requires at
+                least the Startup plan).
+        """
+        self.options = {"wait_for_model": True, "use_gpu": gpu}
+        self.headers = build_hf_headers(token=token)
+
+        # Configure task
+        model_info = HfApi(token=token).model_info(repo_id=repo_id)
+        if not model_info.pipeline_tag and not task:
+            raise ValueError(
+                "Task not specified in the repository. Please add it to the model card"
+                " using pipeline_tag"
+                " (https://huggingface.co/docs#how-is-a-models-type-of-inference-api-and-widget-determined)"
+            )
+
+        if task and task != model_info.pipeline_tag:
+            if task not in ALL_TASKS:
+                raise ValueError(f"Invalid task {task}. Make sure it's valid.")
+
+            logger.warning(
+                "You're using a different task than the one specified in the"
+                " repository. Be sure to know what you're doing :)"
+            )
+            self.task = task
+        else:
+            assert model_info.pipeline_tag is not None, "Pipeline tag cannot be None"
+            self.task = model_info.pipeline_tag
+
+        self.api_url = f"{constants.INFERENCE_ENDPOINT}/pipeline/{self.task}/{repo_id}"
+
+    def __repr__(self):
+        # Do not add headers to repr to avoid leaking token.
+        return f"InferenceAPI(api_url='{self.api_url}', task='{self.task}', options={self.options})"
+
+    def __call__(
+        self,
+        inputs: Optional[Union[str, Dict, List[str], List[List[str]]]] = None,
+        params: Optional[Dict] = None,
+        data: Optional[bytes] = None,
+        raw_response: bool = False,
+    ) -> Any:
+        """Make a call to the Inference API.
+
+        Args:
+            inputs (`str` or `Dict` or `List[str]` or `List[List[str]]`, *optional*):
+                Inputs for the prediction.
+            params (`Dict`, *optional*):
+                Additional parameters for the models. Will be sent as `parameters` in the
+                payload.
+            data (`bytes`, *optional*):
+                Bytes content of the request. In this case, leave `inputs` and `params` empty.
+            raw_response (`bool`, defaults to `False`):
+                If `True`, the raw `Response` object is returned. You can parse its content
+                as preferred. By default, the content is parsed into a more practical format
+                (e.g. a JSON dictionary or a PIL Image).
+        """
+        # Build payload
+        payload: Dict[str, Any] = {
+            "options": self.options,
+        }
+        if inputs:
+            payload["inputs"] = inputs
+        if params:
+            payload["parameters"] = params
+
+        # Make API call
+        response = get_session().post(self.api_url, headers=self.headers, json=payload, data=data)
+
+        # Let the user handle the response
+        if raw_response:
+            return response
+
+        # By default, parse the response for the user.
+        content_type = response.headers.get("Content-Type") or ""
+        if content_type.startswith("image"):
+            if not is_pillow_available():
+                raise ImportError(
+                    f"Task '{self.task}' returned as image but Pillow is not installed."
+                    " Please install it (`pip install Pillow`) or pass"
+                    " `raw_response=True` to get the raw `Response` object and parse"
+                    " the image by yourself."
+                )
+
+            from PIL import Image
+
+            return Image.open(io.BytesIO(response.content))
+        elif content_type == "application/json":
+            return response.json()
+        else:
+            raise NotImplementedError(
+                f"{content_type} output type is not implemented yet. You can pass"
+                " `raw_response=True` to get the raw `Response` object and parse the"
+                " output by yourself."
+            )