aboutsummaryrefslogtreecommitdiff
path: root/R2R/r2r/base/utils/base_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'R2R/r2r/base/utils/base_utils.py')
-rwxr-xr-xR2R/r2r/base/utils/base_utils.py63
1 files changed, 63 insertions, 0 deletions
diff --git a/R2R/r2r/base/utils/base_utils.py b/R2R/r2r/base/utils/base_utils.py
new file mode 100755
index 00000000..12652833
--- /dev/null
+++ b/R2R/r2r/base/utils/base_utils.py
@@ -0,0 +1,63 @@
+import asyncio
+import uuid
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Iterable
+
+if TYPE_CHECKING:
+ from ..pipeline.base_pipeline import AsyncPipeline
+
+
+def generate_run_id() -> uuid.UUID:
+ return uuid.uuid4()
+
+
+def generate_id_from_label(label: str) -> uuid.UUID:
+ return uuid.uuid5(uuid.NAMESPACE_DNS, label)
+
+
+async def to_async_generator(
+ iterable: Iterable[Any],
+) -> AsyncGenerator[Any, None]:
+ for item in iterable:
+ yield item
+
+
+def run_pipeline(pipeline: "AsyncPipeline", input: Any, *args, **kwargs):
+ if not isinstance(input, AsyncGenerator) and not isinstance(input, list):
+ input = to_async_generator([input])
+ elif not isinstance(input, AsyncGenerator):
+ input = to_async_generator(input)
+
+ async def _run_pipeline(input, *args, **kwargs):
+ return await pipeline.run(input, *args, **kwargs)
+
+ return asyncio.run(_run_pipeline(input, *args, **kwargs))
+
+
+def increment_version(version: str) -> str:
+ prefix = version[:-1]
+ suffix = int(version[-1])
+ return f"{prefix}{suffix + 1}"
+
+
+class EntityType:
+ def __init__(self, name: str):
+ self.name = name
+
+
+class Relation:
+ def __init__(self, name: str):
+ self.name = name
+
+
+def format_entity_types(entity_types: list[EntityType]) -> str:
+ lines = []
+ for entity in entity_types:
+ lines.append(entity.name)
+ return "\n".join(lines)
+
+
+def format_relations(predicates: list[Relation]) -> str:
+ lines = []
+ for predicate in predicates:
+ lines.append(predicate.name)
+ return "\n".join(lines)