1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
|
import os
import uuid
from typing import TYPE_CHECKING
import fire
if TYPE_CHECKING:
from r2r.main.execution import R2RExecutionWrapper
class SampleDataIngestor:
    """Ingest the bundled example documents through an execution wrapper.

    Every ingested file is paired with a ``{"user_id": ...}`` metadata dict
    whose value is taken round-robin from ``USER_IDS``.
    """

    # User ids cycled over the sample files. The ``None`` entry produces a
    # metadata dict whose ``user_id`` is ``None``.
    USER_IDS = [
        "063edaf8-3e63-4cb9-a4d6-a855f36376c3",
        "45c3f5a8-bcbe-43b1-9b20-51c07fd79f14",
        "c6c23d85-6217-4caa-b391-91ec0021a000",
        None,
    ]

    # File extensions (lower-case, without the dot) that ``no_media`` filters out.
    _MEDIA_EXTENSIONS = frozenset({"jpeg", "jpg", "png", "svg", "mp3", "mp4"})

    def __init__(
        self,
        executor: "R2RExecutionWrapper",
    ):
        """Store the executor whose ``ingest_files`` method performs the ingestion."""
        self.executor = executor

    @classmethod
    def _user_id_at(cls, index: int) -> "uuid.UUID | None":
        """Return the ``USER_IDS`` entry for *index* (wrapping around) as a UUID.

        A falsy entry (the ``None`` placeholder) is returned as ``None``
        instead of being handed to ``uuid.UUID``, which would raise.
        """
        raw_id = cls.USER_IDS[index % len(cls.USER_IDS)]
        return uuid.UUID(raw_id) if raw_id else None

    @staticmethod
    def get_sample_files(no_media: bool = True) -> list[str]:
        """Return the paths of the sample files under ``../data``.

        The paths are built relative to this module's directory; they are
        not checked for existence.

        Args:
            no_media: When true (the default), drop files whose extension is
                an image, audio or video type.

        Returns:
            The sample file paths, in a fixed order.
        """
        examples_dir = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), ".."
        )
        data_dir = os.path.join(examples_dir, "data")
        file_names = (
            "aristotle.txt",
            "got.txt",
            "screen_shot.png",
            "pg_essay_1.html",
            "pg_essay_2.html",
            "pg_essay_3.html",
            "pg_essay_4.html",
            "pg_essay_5.html",
            "lyft_2021.pdf",
            "uber_2021.pdf",
            "sample.mp3",
            "sample2.mp3",
        )
        files = [os.path.join(data_dir, name) for name in file_names]

        if no_media:
            # splitext only looks at the final path component, so dots in
            # directory names cannot be mistaken for an extension.
            files = [
                file_path
                for file_path in files
                if os.path.splitext(file_path)[1].lstrip(".").lower()
                not in SampleDataIngestor._MEDIA_EXTENSIONS
            ]

        return files

    def ingest_sample_files(self, no_media: bool = True):
        """Ingest every sample file, assigning user ids round-robin.

        Args:
            no_media: Forwarded to ``get_sample_files``.

        Returns:
            Whatever ``executor.ingest_files`` returns.
        """
        sample_files = self.get_sample_files(no_media)
        metadatas = [
            {"user_id": self._user_id_at(position)}
            for position in range(len(sample_files))
        ]
        return self.executor.ingest_files(sample_files, metadatas)

    def ingest_sample_file(self, no_media: bool = True, option: int = 0):
        """Ingest a single sample file chosen by index.

        Args:
            no_media: Forwarded to ``get_sample_files``; it determines which
                list ``option`` indexes into.
            option: Index into the sample-file list. Negative values index
                from the end, as with a normal list.

        Returns:
            Whatever ``executor.ingest_files`` returns.

        Raises:
            IndexError: If ``option`` is outside the sample-file list.
        """
        # Honour ``no_media`` (it used to be accepted but silently ignored).
        sample_files = self.get_sample_files(no_media)
        if not -len(sample_files) <= option < len(sample_files):
            raise IndexError(
                f"option {option} is out of range for "
                f"{len(sample_files)} sample files"
            )

        # Shares the None-safe conversion with ingest_sample_files; calling
        # uuid.UUID(None) directly raises TypeError.
        user_id = self._user_id_at(option)
        return self.executor.ingest_files(
            [sample_files[option]], [{"user_id": user_id}]
        )
# Script entry point: when run directly, hand SampleDataIngestor to
# python-fire so it can be driven from the command line.
# NOTE(review): the constructor requires an `executor` argument; confirm how
# the CLI is expected to supply one.
if __name__ == "__main__":
    fire.Fire(SampleDataIngestor)
|