aboutsummaryrefslogtreecommitdiff
path: root/R2R/r2r/base/utils/base_utils.py
blob: 126528331cbfb306eb8cc821b3f7a7708577f881 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import asyncio
import uuid
from typing import TYPE_CHECKING, Any, AsyncGenerator, Iterable

if TYPE_CHECKING:
    from ..pipeline.base_pipeline import AsyncPipeline


def generate_run_id() -> uuid.UUID:
    """Return a freshly generated random (version 4) UUID."""
    run_id = uuid.uuid4()
    return run_id


def generate_id_from_label(label: str) -> uuid.UUID:
    """Derive a deterministic version 5 UUID from ``label``.

    The label is hashed inside ``uuid.NAMESPACE_DNS``, so equal labels
    always produce equal UUIDs.
    """
    namespace = uuid.NAMESPACE_DNS
    return uuid.uuid5(namespace, label)


async def to_async_generator(
    iterable: Iterable[Any],
) -> AsyncGenerator[Any, None]:
    """Expose a synchronous iterable as an async generator.

    Elements are yielded one at a time, in the order ``iterable``
    produces them.
    """
    for element in iterable:
        yield element


def run_pipeline(pipeline: "AsyncPipeline", input: Any, *args, **kwargs):
    """Run ``pipeline`` to completion from synchronous code.

    ``input`` is normalised to an async generator before it is handed to
    ``pipeline.run``:

    * an async generator is passed through unchanged;
    * a list is streamed element by element;
    * any other value is streamed as a single item.

    Extra positional and keyword arguments are forwarded to
    ``pipeline.run``. The awaited result of ``pipeline.run`` is returned.

    The coroutine is driven with ``asyncio.run``, which cannot be called
    while another event loop is already running in the same thread.
    """
    if isinstance(input, AsyncGenerator):
        stream = input
    elif isinstance(input, list):
        stream = to_async_generator(input)
    else:
        # Scalars (and non-list containers) become a one-item stream.
        stream = to_async_generator([input])

    async def _execute():
        # Call ``pipeline.run`` from inside the event loop, then await it.
        return await pipeline.run(stream, *args, **kwargs)

    return asyncio.run(_execute())


def increment_version(version: str) -> str:
    """Return ``version`` with its trailing number incremented by one.

    The whole run of decimal digits at the end of the string is treated
    as the number, so multi-digit suffixes roll over correctly
    (``"v19"`` becomes ``"v20"``, not ``"v110"``). Zero padding of the
    numeric suffix is preserved when the incremented value still fits
    (``"v009"`` becomes ``"v010"``).

    Args:
        version: A version label ending in at least one decimal digit,
            e.g. ``"v0"`` or ``"1.0.9"``.

    Returns:
        The label with the same prefix and the trailing number plus one.

    Raises:
        ValueError: If ``version`` does not end in a decimal digit
            (including the empty string).

    >>> increment_version("v0")
    'v1'
    >>> increment_version("v19")
    'v20'
    """
    # Walk back from the end to find where the trailing digit run starts.
    split_at = len(version)
    while split_at > 0 and version[split_at - 1].isdecimal():
        split_at -= 1

    prefix, digits = version[:split_at], version[split_at:]
    if not digits:
        raise ValueError(
            f"version {version!r} does not end with a numeric suffix"
        )

    # Pad to the original width so zero-padded suffixes keep their shape.
    return f"{prefix}{int(digits) + 1:0{len(digits)}d}"


class EntityType:
    """Simple holder for the name of an entity type."""

    def __init__(self, name: str):
        # The only state kept on the instance is the given name.
        self.name: str = name


class Relation:
    """Simple holder for the name of a relation."""

    def __init__(self, name: str):
        # The only state kept on the instance is the given name.
        self.name: str = name


def format_entity_types(entity_types: list[EntityType]) -> str:
    """Join the ``name`` of every entity type, one per line.

    Names appear in input order and are separated by ``"\\n"``; an empty
    list yields an empty string.
    """
    return "\n".join(entity_type.name for entity_type in entity_types)


def format_relations(predicates: list[Relation]) -> str:
    """Join the ``name`` of every relation, one per line.

    Names appear in input order and are separated by ``"\\n"``; an empty
    list yields an empty string.
    """
    return "\n".join(relation.name for relation in predicates)