about summary refs log tree commit diff
path: root/R2R/r2r/examples/scripts/sample_data_ingestor.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/examples/scripts/sample_data_ingestor.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to 'R2R/r2r/examples/scripts/sample_data_ingestor.py')
-rwxr-xr-xR2R/r2r/examples/scripts/sample_data_ingestor.py81
1 files changed, 81 insertions, 0 deletions
diff --git a/R2R/r2r/examples/scripts/sample_data_ingestor.py b/R2R/r2r/examples/scripts/sample_data_ingestor.py
new file mode 100755
index 00000000..67eecd16
--- /dev/null
+++ b/R2R/r2r/examples/scripts/sample_data_ingestor.py
@@ -0,0 +1,81 @@
+import os
+import uuid
+from typing import TYPE_CHECKING
+
+import fire
+
+if TYPE_CHECKING:
+    from r2r.main.execution import R2RExecutionWrapper
+
+
+class SampleDataIngestor:
+    USER_IDS = [
+        "063edaf8-3e63-4cb9-a4d6-a855f36376c3",
+        "45c3f5a8-bcbe-43b1-9b20-51c07fd79f14",
+        "c6c23d85-6217-4caa-b391-91ec0021a000",
+        None,
+    ]
+
+    def __init__(
+        self,
+        executor: "R2RExecutionWrapper",
+    ):
+        self.executor = executor
+
+    @staticmethod
+    def get_sample_files(no_media: bool = True) -> list[str]:
+        examples_dir = os.path.join(
+            os.path.dirname(os.path.abspath(__file__)), ".."
+        )
+
+        files = [
+            os.path.join(examples_dir, "data", "aristotle.txt"),
+            os.path.join(examples_dir, "data", "got.txt"),
+            os.path.join(examples_dir, "data", "screen_shot.png"),
+            os.path.join(examples_dir, "data", "pg_essay_1.html"),
+            os.path.join(examples_dir, "data", "pg_essay_2.html"),
+            os.path.join(examples_dir, "data", "pg_essay_3.html"),
+            os.path.join(examples_dir, "data", "pg_essay_4.html"),
+            os.path.join(examples_dir, "data", "pg_essay_5.html"),
+            os.path.join(examples_dir, "data", "lyft_2021.pdf"),
+            os.path.join(examples_dir, "data", "uber_2021.pdf"),
+            os.path.join(examples_dir, "data", "sample.mp3"),
+            os.path.join(examples_dir, "data", "sample2.mp3"),
+        ]
+        if no_media:
+            excluded_types = ["jpeg", "jpg", "png", "svg", "mp3", "mp4"]
+            files = [
+                file_path
+                for file_path in files
+                if file_path.split(".")[-1].lower() not in excluded_types
+            ]
+        return files
+
+    def ingest_sample_files(self, no_media: bool = True):
+        sample_files = self.get_sample_files(no_media)
+        user_ids = [
+            uuid.UUID(user_id) if user_id else None
+            for user_id in self.USER_IDS
+        ]
+
+        response = self.executor.ingest_files(
+            sample_files,
+            [
+                {"user_id": user_ids[it % len(user_ids)]}
+                for it in range(len(sample_files))
+            ],
+        )
+        return response
+
+    def ingest_sample_file(self, no_media: bool = True, option: int = 0):
+        sample_files = self.get_sample_files()
+        user_id = uuid.UUID(self.USER_IDS[option % len(self.USER_IDS)])
+
+        response = self.executor.ingest_files(
+            [sample_files[option]], [{"user_id": user_id}]
+        )
+        return response
+
+
+if __name__ == "__main__":
+    fire.Fire(SampleDataIngestor)