import os
import uuid
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    from r2r.main.execution import R2RExecutionWrapper


class SampleDataIngestor:
    """Ingest the bundled example data files through an R2R executor.

    Every ingested file gets a ``{"user_id": ...}`` metadata entry. The ids
    are drawn from ``USER_IDS``; a ``None`` entry is passed through as
    ``None`` instead of being parsed as a UUID.
    """

    # Pool of user ids attached to ingested files, used round-robin.
    # The trailing ``None`` yields files ingested with ``user_id=None``.
    USER_IDS = [
        "063edaf8-3e63-4cb9-a4d6-a855f36376c3",
        "45c3f5a8-bcbe-43b1-9b20-51c07fd79f14",
        "c6c23d85-6217-4caa-b391-91ec0021a000",
        None,
    ]

    def __init__(
        self,
        executor: "R2RExecutionWrapper",
    ):
        self.executor = executor

    @staticmethod
    def _parse_user_id(user_id: Optional[str]) -> Optional[uuid.UUID]:
        """Convert a ``USER_IDS`` entry to a ``uuid.UUID``; falsy entries become ``None``."""
        return uuid.UUID(user_id) if user_id else None

    @staticmethod
    def get_sample_files(no_media: bool = True) -> list[str]:
        """Return the paths of the bundled sample files.

        Paths are built under ``<this module's directory>/../data``; file
        existence is not checked.

        Args:
            no_media: When true, drop files whose extension is an image,
                audio, or video type (jpeg, jpg, png, svg, mp3, mp4).

        Returns:
            The sample file paths, in a fixed order.
        """
        examples_dir = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), ".."
        )
        data_dir = os.path.join(examples_dir, "data")
        file_names = [
            "aristotle.txt",
            "got.txt",
            "screen_shot.png",
            "pg_essay_1.html",
            "pg_essay_2.html",
            "pg_essay_3.html",
            "pg_essay_4.html",
            "pg_essay_5.html",
            "lyft_2021.pdf",
            "uber_2021.pdf",
            "sample.mp3",
            "sample2.mp3",
        ]
        files = [os.path.join(data_dir, name) for name in file_names]

        if no_media:
            excluded_types = {"jpeg", "jpg", "png", "svg", "mp3", "mp4"}
            # splitext only inspects the final path component, so dots in
            # directory names cannot be mistaken for an extension.
            files = [
                file_path
                for file_path in files
                if os.path.splitext(file_path)[1].lstrip(".").lower()
                not in excluded_types
            ]
        return files

    def ingest_sample_files(self, no_media: bool = True):
        """Ingest all sample files, assigning ``USER_IDS`` round-robin.

        Args:
            no_media: Forwarded to ``get_sample_files``.

        Returns:
            Whatever ``self.executor.ingest_files`` returns.
        """
        sample_files = self.get_sample_files(no_media)
        user_ids = [self._parse_user_id(user_id) for user_id in self.USER_IDS]
        metadatas = [
            {"user_id": user_ids[it % len(user_ids)]}
            for it in range(len(sample_files))
        ]
        return self.executor.ingest_files(sample_files, metadatas)

    def ingest_sample_file(self, no_media: bool = True, option: int = 0):
        """Ingest a single sample file selected by index.

        Args:
            no_media: Forwarded to ``get_sample_files``; it determines which
                list ``option`` indexes into.
            option: Index into the sample file list. It also selects the
                user id, ``USER_IDS[option % len(USER_IDS)]``.

        Returns:
            Whatever ``self.executor.ingest_files`` returns.

        Raises:
            IndexError: If ``option`` is out of range for the file list.
        """
        sample_files = self.get_sample_files(no_media)
        user_id = self._parse_user_id(
            self.USER_IDS[option % len(self.USER_IDS)]
        )
        return self.executor.ingest_files(
            [sample_files[option]], [{"user_id": user_id}]
        )


if __name__ == "__main__":
    # Imported here so the class can be imported without the CLI dependency.
    import fire

    fire.Fire(SampleDataIngestor)