aboutsummaryrefslogtreecommitdiff
path: root/R2R/r2r/examples/scripts/sample_data_ingestor.py
diff options
context:
space:
mode:
Diffstat (limited to 'R2R/r2r/examples/scripts/sample_data_ingestor.py')
-rwxr-xr-xR2R/r2r/examples/scripts/sample_data_ingestor.py81
1 files changed, 81 insertions, 0 deletions
diff --git a/R2R/r2r/examples/scripts/sample_data_ingestor.py b/R2R/r2r/examples/scripts/sample_data_ingestor.py
new file mode 100755
index 00000000..67eecd16
--- /dev/null
+++ b/R2R/r2r/examples/scripts/sample_data_ingestor.py
@@ -0,0 +1,81 @@
+import os
+import uuid
+from typing import TYPE_CHECKING
+
+import fire
+
+if TYPE_CHECKING:
+ from r2r.main.execution import R2RExecutionWrapper
+
+
+class SampleDataIngestor:
+ USER_IDS = [
+ "063edaf8-3e63-4cb9-a4d6-a855f36376c3",
+ "45c3f5a8-bcbe-43b1-9b20-51c07fd79f14",
+ "c6c23d85-6217-4caa-b391-91ec0021a000",
+ None,
+ ]
+
+ def __init__(
+ self,
+ executor: "R2RExecutionWrapper",
+ ):
+ self.executor = executor
+
+ @staticmethod
+ def get_sample_files(no_media: bool = True) -> list[str]:
+ examples_dir = os.path.join(
+ os.path.dirname(os.path.abspath(__file__)), ".."
+ )
+
+ files = [
+ os.path.join(examples_dir, "data", "aristotle.txt"),
+ os.path.join(examples_dir, "data", "got.txt"),
+ os.path.join(examples_dir, "data", "screen_shot.png"),
+ os.path.join(examples_dir, "data", "pg_essay_1.html"),
+ os.path.join(examples_dir, "data", "pg_essay_2.html"),
+ os.path.join(examples_dir, "data", "pg_essay_3.html"),
+ os.path.join(examples_dir, "data", "pg_essay_4.html"),
+ os.path.join(examples_dir, "data", "pg_essay_5.html"),
+ os.path.join(examples_dir, "data", "lyft_2021.pdf"),
+ os.path.join(examples_dir, "data", "uber_2021.pdf"),
+ os.path.join(examples_dir, "data", "sample.mp3"),
+ os.path.join(examples_dir, "data", "sample2.mp3"),
+ ]
+ if no_media:
+ excluded_types = ["jpeg", "jpg", "png", "svg", "mp3", "mp4"]
+ files = [
+ file_path
+ for file_path in files
+ if file_path.split(".")[-1].lower() not in excluded_types
+ ]
+ return files
+
+ def ingest_sample_files(self, no_media: bool = True):
+ sample_files = self.get_sample_files(no_media)
+ user_ids = [
+ uuid.UUID(user_id) if user_id else None
+ for user_id in self.USER_IDS
+ ]
+
+ response = self.executor.ingest_files(
+ sample_files,
+ [
+ {"user_id": user_ids[it % len(user_ids)]}
+ for it in range(len(sample_files))
+ ],
+ )
+ return response
+
+ def ingest_sample_file(self, no_media: bool = True, option: int = 0):
+ sample_files = self.get_sample_files()
+ user_id = uuid.UUID(self.USER_IDS[option % len(self.USER_IDS)])
+
+ response = self.executor.ingest_files(
+ [sample_files[option]], [{"user_id": user_id}]
+ )
+ return response
+
+
+if __name__ == "__main__":
+ fire.Fire(SampleDataIngestor)