diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/huggingface_hub/inference/_providers/fal_ai.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/huggingface_hub/inference/_providers/fal_ai.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/huggingface_hub/inference/_providers/fal_ai.py | 90 |
1 files changed, 90 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/huggingface_hub/inference/_providers/fal_ai.py b/.venv/lib/python3.12/site-packages/huggingface_hub/inference/_providers/fal_ai.py new file mode 100644 index 00000000..e6d2e4bd --- /dev/null +++ b/.venv/lib/python3.12/site-packages/huggingface_hub/inference/_providers/fal_ai.py @@ -0,0 +1,90 @@ +import base64 +from abc import ABC +from typing import Any, Dict, Optional, Union + +from huggingface_hub.inference._common import _as_dict +from huggingface_hub.inference._providers._common import TaskProviderHelper, filter_none +from huggingface_hub.utils import get_session + + +class FalAITask(TaskProviderHelper, ABC): + def __init__(self, task: str): + super().__init__(provider="fal-ai", base_url="https://fal.run", task=task) + + def _prepare_headers(self, headers: Dict, api_key: str) -> Dict: + headers = super()._prepare_headers(headers, api_key) + if not api_key.startswith("hf_"): + headers["authorization"] = f"Key {api_key}" + return headers + + def _prepare_route(self, mapped_model: str) -> str: + return f"/{mapped_model}" + + +class FalAIAutomaticSpeechRecognitionTask(FalAITask): + def __init__(self): + super().__init__("automatic-speech-recognition") + + def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]: + if isinstance(inputs, str) and inputs.startswith(("http://", "https://")): + # If input is a URL, pass it directly + audio_url = inputs + else: + # If input is a file path, read it first + if isinstance(inputs, str): + with open(inputs, "rb") as f: + inputs = f.read() + + audio_b64 = base64.b64encode(inputs).decode() + content_type = "audio/mpeg" + audio_url = f"data:{content_type};base64,{audio_b64}" + + return {"audio_url": audio_url, **filter_none(parameters)} + + def get_response(self, response: Union[bytes, Dict]) -> Any: + text = _as_dict(response)["text"] + if not isinstance(text, str): + raise ValueError(f"Unexpected output format from FalAI API. Expected string, got {type(text)}.") + return text + + +class FalAITextToImageTask(FalAITask): + def __init__(self): + super().__init__("text-to-image") + + def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]: + parameters = filter_none(parameters) + if "width" in parameters and "height" in parameters: + parameters["image_size"] = { + "width": parameters.pop("width"), + "height": parameters.pop("height"), + } + return {"prompt": inputs, **parameters} + + def get_response(self, response: Union[bytes, Dict]) -> Any: + url = _as_dict(response)["images"][0]["url"] + return get_session().get(url).content + + +class FalAITextToSpeechTask(FalAITask): + def __init__(self): + super().__init__("text-to-speech") + + def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]: + return {"lyrics": inputs, **filter_none(parameters)} + + def get_response(self, response: Union[bytes, Dict]) -> Any: + url = _as_dict(response)["audio"]["url"] + return get_session().get(url).content + + +class FalAITextToVideoTask(FalAITask): + def __init__(self): + super().__init__("text-to-video") + + def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]: + return {"prompt": inputs, **filter_none(parameters)} + + def get_response(self, response: Union[bytes, Dict]) -> Any: + url = _as_dict(response)["video"]["url"] + return get_session().get(url).content |