aboutsummaryrefslogtreecommitdiff
path: root/R2R/r2r/examples/scripts/sample_data_ingestor.py
blob: 67eecd167776a0059b669d1fb65866446845288e (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
import uuid
from typing import TYPE_CHECKING

import fire

if TYPE_CHECKING:
    from r2r.main.execution import R2RExecutionWrapper


class SampleDataIngestor:
    """Ingest the bundled example documents through an executor.

    The executor only needs an ``ingest_files(file_paths, metadatas)``
    method; every call made by this class goes through it and its return
    value is passed back to the caller unchanged.
    """

    # Owners cycled over the ingested files. The ``None`` entry yields a
    # ``{"user_id": None}`` metadata record instead of a UUID.
    USER_IDS = [
        "063edaf8-3e63-4cb9-a4d6-a855f36376c3",
        "45c3f5a8-bcbe-43b1-9b20-51c07fd79f14",
        "c6c23d85-6217-4caa-b391-91ec0021a000",
        None,
    ]

    def __init__(
        self,
        executor: "R2RExecutionWrapper",
    ):
        """Store the executor used for all ingestion calls."""
        self.executor = executor

    @staticmethod
    def get_sample_files(no_media: bool = True) -> list[str]:
        """Return the paths of the sample documents.

        Paths are built relative to the ``data`` directory that sits next to
        this script's parent directory; the files are not checked for
        existence.

        Args:
            no_media: When true, drop image/audio/video files (by extension)
                from the result.

        Returns:
            The sample file paths, in a fixed order.
        """
        examples_dir = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), ".."
        )

        file_names = (
            "aristotle.txt",
            "got.txt",
            "screen_shot.png",
            "pg_essay_1.html",
            "pg_essay_2.html",
            "pg_essay_3.html",
            "pg_essay_4.html",
            "pg_essay_5.html",
            "lyft_2021.pdf",
            "uber_2021.pdf",
            "sample.mp3",
            "sample2.mp3",
        )
        files = [
            os.path.join(examples_dir, "data", file_name)
            for file_name in file_names
        ]
        if no_media:
            excluded_types = {"jpeg", "jpg", "png", "svg", "mp3", "mp4"}
            files = [
                file_path
                for file_path in files
                # splitext only inspects the final path component, so dots
                # in directory names (e.g. the "..") cannot be mistaken for
                # an extension separator.
                if os.path.splitext(file_path)[1].lstrip(".").lower()
                not in excluded_types
            ]
        return files

    def _user_id_for(self, index: int) -> "uuid.UUID | None":
        """Return the owner for position *index*, cycling over USER_IDS.

        A ``None`` entry in USER_IDS is returned as ``None`` rather than
        being handed to ``uuid.UUID``, which would raise ``TypeError``.
        """
        raw_user_id = self.USER_IDS[index % len(self.USER_IDS)]
        return uuid.UUID(raw_user_id) if raw_user_id else None

    def ingest_sample_files(self, no_media: bool = True):
        """Ingest every sample file in one executor call.

        Args:
            no_media: Forwarded to :meth:`get_sample_files`.

        Returns:
            Whatever ``executor.ingest_files`` returns.
        """
        sample_files = self.get_sample_files(no_media)
        metadatas = [
            {"user_id": self._user_id_for(position)}
            for position in range(len(sample_files))
        ]
        return self.executor.ingest_files(sample_files, metadatas)

    def ingest_sample_file(self, no_media: bool = True, option: int = 0):
        """Ingest a single sample file selected by *option*.

        Args:
            no_media: Forwarded to :meth:`get_sample_files`, so *option*
                indexes the filtered list when true and the full list when
                false.
            option: Index of the file to ingest; the owner is chosen with
                ``option`` modulo the number of USER_IDS.

        Returns:
            Whatever ``executor.ingest_files`` returns.

        Raises:
            IndexError: If *option* is outside the selected file list.
        """
        # Honour the caller's no_media choice; it used to be silently ignored.
        sample_files = self.get_sample_files(no_media)
        user_id = self._user_id_for(option)

        return self.executor.ingest_files(
            [sample_files[option]], [{"user_id": user_id}]
        )


if __name__ == "__main__":
    # When run as a script, hand the SampleDataIngestor class to python-fire.
    # NOTE(review): __init__ requires an `executor`; how that is supplied on
    # the command line is not visible here — confirm against fire's docs.
    fire.Fire(SampleDataIngestor)