From 4a52a71956a8d46fcb7294ac71734504bb09bcc2 Mon Sep 17 00:00:00 2001 From: S. Solomon Darnell Date: Fri, 28 Mar 2025 21:52:21 -0500 Subject: two version of R2R are here --- .../huggingface_hub/_inference_endpoints.py | 407 +++++++++++++++++++++ 1 file changed, 407 insertions(+) create mode 100644 .venv/lib/python3.12/site-packages/huggingface_hub/_inference_endpoints.py (limited to '.venv/lib/python3.12/site-packages/huggingface_hub/_inference_endpoints.py') diff --git a/.venv/lib/python3.12/site-packages/huggingface_hub/_inference_endpoints.py b/.venv/lib/python3.12/site-packages/huggingface_hub/_inference_endpoints.py new file mode 100644 index 00000000..37733fef --- /dev/null +++ b/.venv/lib/python3.12/site-packages/huggingface_hub/_inference_endpoints.py @@ -0,0 +1,407 @@ +import time +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import TYPE_CHECKING, Dict, Optional, Union + +from huggingface_hub.errors import InferenceEndpointError, InferenceEndpointTimeoutError + +from .inference._client import InferenceClient +from .inference._generated._async_client import AsyncInferenceClient +from .utils import get_session, logging, parse_datetime + + +if TYPE_CHECKING: + from .hf_api import HfApi + + +logger = logging.get_logger(__name__) + + +class InferenceEndpointStatus(str, Enum): + PENDING = "pending" + INITIALIZING = "initializing" + UPDATING = "updating" + UPDATE_FAILED = "updateFailed" + RUNNING = "running" + PAUSED = "paused" + FAILED = "failed" + SCALED_TO_ZERO = "scaledToZero" + + +class InferenceEndpointType(str, Enum): + PUBlIC = "public" + PROTECTED = "protected" + PRIVATE = "private" + + +@dataclass +class InferenceEndpoint: + """ + Contains information about a deployed Inference Endpoint. + + Args: + name (`str`): + The unique name of the Inference Endpoint. + namespace (`str`): + The namespace where the Inference Endpoint is located. + repository (`str`): + The name of the model repository deployed on this Inference Endpoint. + status ([`InferenceEndpointStatus`]): + The current status of the Inference Endpoint. + url (`str`, *optional*): + The URL of the Inference Endpoint, if available. Only a deployed Inference Endpoint will have a URL. + framework (`str`): + The machine learning framework used for the model. + revision (`str`): + The specific model revision deployed on the Inference Endpoint. + task (`str`): + The task associated with the deployed model. + created_at (`datetime.datetime`): + The timestamp when the Inference Endpoint was created. + updated_at (`datetime.datetime`): + The timestamp of the last update of the Inference Endpoint. + type ([`InferenceEndpointType`]): + The type of the Inference Endpoint (public, protected, private). + raw (`Dict`): + The raw dictionary data returned from the API. + token (`str` or `bool`, *optional*): + Authentication token for the Inference Endpoint, if set when requesting the API. Will default to the + locally saved token if not provided. Pass `token=False` if you don't want to send your token to the server. + + Example: + ```python + >>> from huggingface_hub import get_inference_endpoint + >>> endpoint = get_inference_endpoint("my-text-to-image") + >>> endpoint + InferenceEndpoint(name='my-text-to-image', ...) + + # Get status + >>> endpoint.status + 'running' + >>> endpoint.url + 'https://my-text-to-image.region.vendor.endpoints.huggingface.cloud' + + # Run inference + >>> endpoint.client.text_to_image(...) + + # Pause endpoint to save $$$ + >>> endpoint.pause() + + # ... + # Resume and wait for deployment + >>> endpoint.resume() + >>> endpoint.wait() + >>> endpoint.client.text_to_image(...) + ``` + """ + + # Field in __repr__ + name: str = field(init=False) + namespace: str + repository: str = field(init=False) + status: InferenceEndpointStatus = field(init=False) + url: Optional[str] = field(init=False) + + # Other fields + framework: str = field(repr=False, init=False) + revision: str = field(repr=False, init=False) + task: str = field(repr=False, init=False) + created_at: datetime = field(repr=False, init=False) + updated_at: datetime = field(repr=False, init=False) + type: InferenceEndpointType = field(repr=False, init=False) + + # Raw dict from the API + raw: Dict = field(repr=False) + + # Internal fields + _token: Union[str, bool, None] = field(repr=False, compare=False) + _api: "HfApi" = field(repr=False, compare=False) + + @classmethod + def from_raw( + cls, raw: Dict, namespace: str, token: Union[str, bool, None] = None, api: Optional["HfApi"] = None + ) -> "InferenceEndpoint": + """Initialize object from raw dictionary.""" + if api is None: + from .hf_api import HfApi + + api = HfApi() + if token is None: + token = api.token + + # All other fields are populated in __post_init__ + return cls(raw=raw, namespace=namespace, _token=token, _api=api) + + def __post_init__(self) -> None: + """Populate fields from raw dictionary.""" + self._populate_from_raw() + + @property + def client(self) -> InferenceClient: + """Returns a client to make predictions on this Inference Endpoint. + + Returns: + [`InferenceClient`]: an inference client pointing to the deployed endpoint. + + Raises: + [`InferenceEndpointError`]: If the Inference Endpoint is not yet deployed. + """ + if self.url is None: + raise InferenceEndpointError( + "Cannot create a client for this Inference Endpoint as it is not yet deployed. " + "Please wait for the Inference Endpoint to be deployed using `endpoint.wait()` and try again." + ) + return InferenceClient( + model=self.url, + token=self._token, # type: ignore[arg-type] # boolean token shouldn't be possible. In practice it's ok. + ) + + @property + def async_client(self) -> AsyncInferenceClient: + """Returns a client to make predictions on this Inference Endpoint. + + Returns: + [`AsyncInferenceClient`]: an asyncio-compatible inference client pointing to the deployed endpoint. + + Raises: + [`InferenceEndpointError`]: If the Inference Endpoint is not yet deployed. + """ + if self.url is None: + raise InferenceEndpointError( + "Cannot create a client for this Inference Endpoint as it is not yet deployed. " + "Please wait for the Inference Endpoint to be deployed using `endpoint.wait()` and try again." + ) + return AsyncInferenceClient( + model=self.url, + token=self._token, # type: ignore[arg-type] # boolean token shouldn't be possible. In practice it's ok. + ) + + def wait(self, timeout: Optional[int] = None, refresh_every: int = 5) -> "InferenceEndpoint": + """Wait for the Inference Endpoint to be deployed. + + Information from the server will be fetched every 1s. If the Inference Endpoint is not deployed after `timeout` + seconds, a [`InferenceEndpointTimeoutError`] will be raised. The [`InferenceEndpoint`] will be mutated in place with the latest + data. + + Args: + timeout (`int`, *optional*): + The maximum time to wait for the Inference Endpoint to be deployed, in seconds. If `None`, will wait + indefinitely. + refresh_every (`int`, *optional*): + The time to wait between each fetch of the Inference Endpoint status, in seconds. Defaults to 5s. + + Returns: + [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data. + + Raises: + [`InferenceEndpointError`] + If the Inference Endpoint ended up in a failed state. + [`InferenceEndpointTimeoutError`] + If the Inference Endpoint is not deployed after `timeout` seconds. + """ + if timeout is not None and timeout < 0: + raise ValueError("`timeout` cannot be negative.") + if refresh_every <= 0: + raise ValueError("`refresh_every` must be positive.") + + start = time.time() + while True: + if self.status == InferenceEndpointStatus.FAILED: + raise InferenceEndpointError( + f"Inference Endpoint {self.name} failed to deploy. Please check the logs for more information." + ) + if self.status == InferenceEndpointStatus.UPDATE_FAILED: + raise InferenceEndpointError( + f"Inference Endpoint {self.name} failed to update. Please check the logs for more information." + ) + if self.status == InferenceEndpointStatus.RUNNING and self.url is not None: + # Verify the endpoint is actually reachable + response = get_session().get(self.url, headers=self._api._build_hf_headers(token=self._token)) + if response.status_code == 200: + logger.info("Inference Endpoint is ready to be used.") + return self + + if timeout is not None: + if time.time() - start > timeout: + raise InferenceEndpointTimeoutError("Timeout while waiting for Inference Endpoint to be deployed.") + logger.info(f"Inference Endpoint is not deployed yet ({self.status}). Waiting {refresh_every}s...") + time.sleep(refresh_every) + self.fetch() + + def fetch(self) -> "InferenceEndpoint": + """Fetch latest information about the Inference Endpoint. + + Returns: + [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data. + """ + obj = self._api.get_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token) # type: ignore [arg-type] + self.raw = obj.raw + self._populate_from_raw() + return self + + def update( + self, + *, + # Compute update + accelerator: Optional[str] = None, + instance_size: Optional[str] = None, + instance_type: Optional[str] = None, + min_replica: Optional[int] = None, + max_replica: Optional[int] = None, + scale_to_zero_timeout: Optional[int] = None, + # Model update + repository: Optional[str] = None, + framework: Optional[str] = None, + revision: Optional[str] = None, + task: Optional[str] = None, + custom_image: Optional[Dict] = None, + secrets: Optional[Dict[str, str]] = None, + ) -> "InferenceEndpoint": + """Update the Inference Endpoint. + + This method allows the update of either the compute configuration, the deployed model, or both. All arguments are + optional but at least one must be provided. + + This is an alias for [`HfApi.update_inference_endpoint`]. The current object is mutated in place with the + latest data from the server. + + Args: + accelerator (`str`, *optional*): + The hardware accelerator to be used for inference (e.g. `"cpu"`). + instance_size (`str`, *optional*): + The size or type of the instance to be used for hosting the model (e.g. `"x4"`). + instance_type (`str`, *optional*): + The cloud instance type where the Inference Endpoint will be deployed (e.g. `"intel-icl"`). + min_replica (`int`, *optional*): + The minimum number of replicas (instances) to keep running for the Inference Endpoint. + max_replica (`int`, *optional*): + The maximum number of replicas (instances) to scale to for the Inference Endpoint. + scale_to_zero_timeout (`int`, *optional*): + The duration in minutes before an inactive endpoint is scaled to zero. + + repository (`str`, *optional*): + The name of the model repository associated with the Inference Endpoint (e.g. `"gpt2"`). + framework (`str`, *optional*): + The machine learning framework used for the model (e.g. `"custom"`). + revision (`str`, *optional*): + The specific model revision to deploy on the Inference Endpoint (e.g. `"6c0e6080953db56375760c0471a8c5f2929baf11"`). + task (`str`, *optional*): + The task on which to deploy the model (e.g. `"text-classification"`). + custom_image (`Dict`, *optional*): + A custom Docker image to use for the Inference Endpoint. This is useful if you want to deploy an + Inference Endpoint running on the `text-generation-inference` (TGI) framework (see examples). + secrets (`Dict[str, str]`, *optional*): + Secret values to inject in the container environment. + Returns: + [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data. + """ + # Make API call + obj = self._api.update_inference_endpoint( + name=self.name, + namespace=self.namespace, + accelerator=accelerator, + instance_size=instance_size, + instance_type=instance_type, + min_replica=min_replica, + max_replica=max_replica, + scale_to_zero_timeout=scale_to_zero_timeout, + repository=repository, + framework=framework, + revision=revision, + task=task, + custom_image=custom_image, + secrets=secrets, + token=self._token, # type: ignore [arg-type] + ) + + # Mutate current object + self.raw = obj.raw + self._populate_from_raw() + return self + + def pause(self) -> "InferenceEndpoint": + """Pause the Inference Endpoint. + + A paused Inference Endpoint will not be charged. It can be resumed at any time using [`InferenceEndpoint.resume`]. + This is different than scaling the Inference Endpoint to zero with [`InferenceEndpoint.scale_to_zero`], which + would be automatically restarted when a request is made to it. + + This is an alias for [`HfApi.pause_inference_endpoint`]. The current object is mutated in place with the + latest data from the server. + + Returns: + [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data. + """ + obj = self._api.pause_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token) # type: ignore [arg-type] + self.raw = obj.raw + self._populate_from_raw() + return self + + def resume(self, running_ok: bool = True) -> "InferenceEndpoint": + """Resume the Inference Endpoint. + + This is an alias for [`HfApi.resume_inference_endpoint`]. The current object is mutated in place with the + latest data from the server. + + Args: + running_ok (`bool`, *optional*): + If `True`, the method will not raise an error if the Inference Endpoint is already running. Defaults to + `True`. + + Returns: + [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data. + """ + obj = self._api.resume_inference_endpoint( + name=self.name, namespace=self.namespace, running_ok=running_ok, token=self._token + ) # type: ignore [arg-type] + self.raw = obj.raw + self._populate_from_raw() + return self + + def scale_to_zero(self) -> "InferenceEndpoint": + """Scale Inference Endpoint to zero. + + An Inference Endpoint scaled to zero will not be charged. It will be resume on the next request to it, with a + cold start delay. This is different than pausing the Inference Endpoint with [`InferenceEndpoint.pause`], which + would require a manual resume with [`InferenceEndpoint.resume`]. + + This is an alias for [`HfApi.scale_to_zero_inference_endpoint`]. The current object is mutated in place with the + latest data from the server. + + Returns: + [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data. + """ + obj = self._api.scale_to_zero_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token) # type: ignore [arg-type] + self.raw = obj.raw + self._populate_from_raw() + return self + + def delete(self) -> None: + """Delete the Inference Endpoint. + + This operation is not reversible. If you don't want to be charged for an Inference Endpoint, it is preferable + to pause it with [`InferenceEndpoint.pause`] or scale it to zero with [`InferenceEndpoint.scale_to_zero`]. + + This is an alias for [`HfApi.delete_inference_endpoint`]. + """ + self._api.delete_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token) # type: ignore [arg-type] + + def _populate_from_raw(self) -> None: + """Populate fields from raw dictionary. + + Called in __post_init__ + each time the Inference Endpoint is updated. + """ + # Repr fields + self.name = self.raw["name"] + self.repository = self.raw["model"]["repository"] + self.status = self.raw["status"]["state"] + self.url = self.raw["status"].get("url") + + # Other fields + self.framework = self.raw["model"]["framework"] + self.revision = self.raw["model"]["revision"] + self.task = self.raw["model"]["task"] + self.created_at = parse_datetime(self.raw["status"]["createdAt"]) + self.updated_at = parse_datetime(self.raw["status"]["updatedAt"]) + self.type = self.raw["type"] -- cgit v1.2.3