diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/huggingface_hub/inference/_providers/__init__.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/huggingface_hub/inference/_providers/__init__.py | 135 |
1 files changed, 135 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/huggingface_hub/inference/_providers/__init__.py b/.venv/lib/python3.12/site-packages/huggingface_hub/inference/_providers/__init__.py new file mode 100644 index 00000000..34003125 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/huggingface_hub/inference/_providers/__init__.py @@ -0,0 +1,135 @@ +from typing import Dict, Literal + +from ._common import TaskProviderHelper +from .black_forest_labs import BlackForestLabsTextToImageTask +from .cerebras import CerebrasConversationalTask +from .cohere import CohereConversationalTask +from .fal_ai import ( + FalAIAutomaticSpeechRecognitionTask, + FalAITextToImageTask, + FalAITextToSpeechTask, + FalAITextToVideoTask, +) +from .fireworks_ai import FireworksAIConversationalTask +from .hf_inference import HFInferenceBinaryInputTask, HFInferenceConversational, HFInferenceTask +from .hyperbolic import HyperbolicTextGenerationTask, HyperbolicTextToImageTask +from .nebius import NebiusConversationalTask, NebiusTextGenerationTask, NebiusTextToImageTask +from .novita import NovitaConversationalTask, NovitaTextGenerationTask +from .replicate import ReplicateTask, ReplicateTextToSpeechTask +from .sambanova import SambanovaConversationalTask +from .together import TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask + + +PROVIDER_T = Literal[ + "black-forest-labs", + "cerebras", + "cohere", + "fal-ai", + "fireworks-ai", + "hf-inference", + "hyperbolic", + "nebius", + "novita", + "replicate", + "sambanova", + "together", +] + +PROVIDERS: Dict[PROVIDER_T, Dict[str, TaskProviderHelper]] = { + "black-forest-labs": { + "text-to-image": BlackForestLabsTextToImageTask(), + }, + "cerebras": { + "conversational": CerebrasConversationalTask(), + }, + "cohere": { + "conversational": CohereConversationalTask(), + }, + "fal-ai": { + "automatic-speech-recognition": FalAIAutomaticSpeechRecognitionTask(), + "text-to-image": FalAITextToImageTask(), + "text-to-speech": FalAITextToSpeechTask(), + "text-to-video": FalAITextToVideoTask(), + }, + "fireworks-ai": { + "conversational": FireworksAIConversationalTask(), + }, + "hf-inference": { + "text-to-image": HFInferenceTask("text-to-image"), + "conversational": HFInferenceConversational(), + "text-generation": HFInferenceTask("text-generation"), + "text-classification": HFInferenceTask("text-classification"), + "question-answering": HFInferenceTask("question-answering"), + "audio-classification": HFInferenceBinaryInputTask("audio-classification"), + "automatic-speech-recognition": HFInferenceBinaryInputTask("automatic-speech-recognition"), + "fill-mask": HFInferenceTask("fill-mask"), + "feature-extraction": HFInferenceTask("feature-extraction"), + "image-classification": HFInferenceBinaryInputTask("image-classification"), + "image-segmentation": HFInferenceBinaryInputTask("image-segmentation"), + "document-question-answering": HFInferenceTask("document-question-answering"), + "image-to-text": HFInferenceBinaryInputTask("image-to-text"), + "object-detection": HFInferenceBinaryInputTask("object-detection"), + "audio-to-audio": HFInferenceBinaryInputTask("audio-to-audio"), + "zero-shot-image-classification": HFInferenceBinaryInputTask("zero-shot-image-classification"), + "zero-shot-classification": HFInferenceTask("zero-shot-classification"), + "image-to-image": HFInferenceBinaryInputTask("image-to-image"), + "sentence-similarity": HFInferenceTask("sentence-similarity"), + "table-question-answering": HFInferenceTask("table-question-answering"), + "tabular-classification": HFInferenceTask("tabular-classification"), + "text-to-speech": HFInferenceTask("text-to-speech"), + "token-classification": HFInferenceTask("token-classification"), + "translation": HFInferenceTask("translation"), + "summarization": HFInferenceTask("summarization"), + "visual-question-answering": HFInferenceBinaryInputTask("visual-question-answering"), + }, + "hyperbolic": { + "text-to-image": HyperbolicTextToImageTask(), + "conversational": HyperbolicTextGenerationTask("conversational"), + "text-generation": HyperbolicTextGenerationTask("text-generation"), + }, + "nebius": { + "text-to-image": NebiusTextToImageTask(), + "conversational": NebiusConversationalTask(), + "text-generation": NebiusTextGenerationTask(), + }, + "novita": { + "text-generation": NovitaTextGenerationTask(), + "conversational": NovitaConversationalTask(), + }, + "replicate": { + "text-to-image": ReplicateTask("text-to-image"), + "text-to-speech": ReplicateTextToSpeechTask(), + "text-to-video": ReplicateTask("text-to-video"), + }, + "sambanova": { + "conversational": SambanovaConversationalTask(), + }, + "together": { + "text-to-image": TogetherTextToImageTask(), + "conversational": TogetherConversationalTask(), + "text-generation": TogetherTextGenerationTask(), + }, +} + + +def get_provider_helper(provider: PROVIDER_T, task: str) -> TaskProviderHelper: + """Get provider helper instance by name and task. + + Args: + provider (str): Name of the provider + task (str): Name of the task + + Returns: + TaskProviderHelper: Helper instance for the specified provider and task + + Raises: + ValueError: If provider or task is not supported + """ + if provider not in PROVIDERS: + raise ValueError(f"Provider '{provider}' not supported. Available providers: {list(PROVIDERS.keys())}") + if task not in PROVIDERS[provider]: + raise ValueError( + f"Task '{task}' not supported for provider '{provider}'. " + f"Available tasks: {list(PROVIDERS[provider].keys())}" + ) + return PROVIDERS[provider][task] |