author     S. Solomon Darnell    2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell    2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/providers/database/files.py
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download   gn-ai-master.tar.gz

two versions of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/providers/database/files.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/files.py  334
1 file changed, 334 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/files.py b/.venv/lib/python3.12/site-packages/core/providers/database/files.py
new file mode 100644
index 00000000..dc349a7e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/files.py
@@ -0,0 +1,334 @@
+import io
+import logging
+from datetime import datetime
+from io import BytesIO
+from typing import BinaryIO, Optional
+from uuid import UUID
+from zipfile import ZipFile
+
+import asyncpg
+from fastapi import HTTPException
+
+from core.base import Handler, R2RException
+
+from .base import PostgresConnectionManager
+
+logger = logging.getLogger()
+
+
+class PostgresFilesHandler(Handler):
+ """PostgreSQL implementation of the FileHandler."""
+
+ TABLE_NAME = "files"
+
+ connection_manager: PostgresConnectionManager
+
+ async def create_tables(self) -> None:
+ """Create the necessary tables for file storage."""
+ query = f"""
+ CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresFilesHandler.TABLE_NAME)} (
+ document_id UUID PRIMARY KEY,
+ name TEXT NOT NULL,
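+            -- 'oid' references the PostgreSQL large object that stores the file bytes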
+ oid OID NOT NULL,
+ size BIGINT NOT NULL,
+ type TEXT,
+ created_at TIMESTAMPTZ DEFAULT NOW(),
+ updated_at TIMESTAMPTZ DEFAULT NOW()
+ );
+
+ -- Create trigger for updating the updated_at timestamp
+ CREATE OR REPLACE FUNCTION {self.project_name}.update_files_updated_at()
+ RETURNS TRIGGER AS $$
+ BEGIN
+ NEW.updated_at = CURRENT_TIMESTAMP;
+ RETURN NEW;
+ END;
+ $$ LANGUAGE plpgsql;
+
+ DROP TRIGGER IF EXISTS update_files_updated_at
+ ON {self._get_table_name(PostgresFilesHandler.TABLE_NAME)};
+
+ CREATE TRIGGER update_files_updated_at
+ BEFORE UPDATE ON {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}
+ FOR EACH ROW
+ EXECUTE FUNCTION {self.project_name}.update_files_updated_at();
+ """
+ await self.connection_manager.execute_query(query)
+
+ async def upsert_file(
+ self,
+ document_id: UUID,
+ file_name: str,
+ file_oid: int,
+ file_size: int,
+ file_type: Optional[str] = None,
+ ) -> None:
+ """Add or update a file entry in storage."""
+ query = f"""
+ INSERT INTO {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}
+ (document_id, name, oid, size, type)
+ VALUES ($1, $2, $3, $4, $5)
+ ON CONFLICT (document_id) DO UPDATE SET
+ name = EXCLUDED.name,
+ oid = EXCLUDED.oid,
+ size = EXCLUDED.size,
+ type = EXCLUDED.type,
+ updated_at = NOW();
+ """
+ await self.connection_manager.execute_query(
+ query, [document_id, file_name, file_oid, file_size, file_type]
+ )
+
+ async def store_file(
+ self,
+ document_id: UUID,
+ file_name: str,
+ file_content: io.BytesIO,
+ file_type: Optional[str] = None,
+ ) -> None:
+ """Store a new file in the database."""
+ size = file_content.getbuffer().nbytes
+
+ async with (
+ self.connection_manager.pool.get_connection() as conn # type: ignore
+ ):
+ async with conn.transaction():
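+                # lo_create(0) asks the server to allocate a fresh OID for the new large object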
+ oid = await conn.fetchval("SELECT lo_create(0)")
+ await self._write_lobject(conn, oid, file_content)
+ await self.upsert_file(
+ document_id, file_name, oid, size, file_type
+ )
+
+ async def _write_lobject(
+ self, conn, oid: int, file_content: io.BytesIO
+ ) -> None:
+ """Write content to a large object."""
+ lobject = await conn.fetchval("SELECT lo_open($1, $2)", oid, 0x20000)
+
+ try:
+ chunk_size = 8192 # 8 KB chunks
+ while True:
+ if chunk := file_content.read(chunk_size):
+ await conn.execute(
+ "SELECT lowrite($1, $2)", lobject, chunk
+ )
+ else:
+ break
+
+ await conn.execute("SELECT lo_close($1)", lobject)
+
+ except Exception as e:
+ await conn.execute("SELECT lo_unlink($1)", oid)
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to write to large object: {e}",
+ ) from e
+
+    async def retrieve_file(
+        self, document_id: UUID
+    ) -> tuple[str, BinaryIO, int]:
+ """Retrieve a file from storage."""
+ query = f"""
+ SELECT name, oid, size
+ FROM {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}
+ WHERE document_id = $1
+ """
+
+ result = await self.connection_manager.fetchrow_query(
+ query, [document_id]
+ )
+ if not result:
+ raise R2RException(
+ status_code=404,
+ message=f"File for document {document_id} not found",
+ )
+
+ file_name, oid, size = (
+ result["name"],
+ result["oid"],
+ result["size"],
+ )
+
+ async with self.connection_manager.pool.get_connection() as conn: # type: ignore
+ file_content = await self._read_lobject(conn, oid)
+ return file_name, io.BytesIO(file_content), size
+
+ async def retrieve_files_as_zip(
+ self,
+ document_ids: Optional[list[UUID]] = None,
+ start_date: Optional[datetime] = None,
+ end_date: Optional[datetime] = None,
+ ) -> tuple[str, BinaryIO, int]:
+ """Retrieve multiple files and return them as a zip file."""
+
+ query = f"""
+ SELECT document_id, name, oid, size
+ FROM {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}
+ WHERE 1=1
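+            -- 1=1 lets each optional filter below be appended as an AND clause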
+ """
+ params: list = []
+
+ if document_ids:
+ query += f" AND document_id = ANY(${len(params) + 1})"
+ params.append([str(doc_id) for doc_id in document_ids])
+
+ if start_date:
+ query += f" AND created_at >= ${len(params) + 1}"
+ params.append(start_date)
+
+ if end_date:
+ query += f" AND created_at <= ${len(params) + 1}"
+ params.append(end_date)
+
+ query += " ORDER BY created_at DESC"
+
+ results = await self.connection_manager.fetch_query(query, params)
+
+ if not results:
+ raise R2RException(
+ status_code=404,
+ message="No files found matching the specified criteria",
+ )
+
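+        # The zip archive is assembled entirely in memory; peak RAM grows with export size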
+ zip_buffer = BytesIO()
+ total_size = 0
+
+ async with self.connection_manager.pool.get_connection() as conn: # type: ignore
+ with ZipFile(zip_buffer, "w") as zip_file:
+ for record in results:
+ file_content = await self._read_lobject(
+ conn, record["oid"]
+ )
+
+ zip_file.writestr(record["name"], file_content)
+ total_size += record["size"]
+
+ zip_buffer.seek(0)
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ zip_filename = f"files_export_{timestamp}.zip"
+
+ return zip_filename, zip_buffer, zip_buffer.getbuffer().nbytes
+
+ async def _read_lobject(self, conn, oid: int) -> bytes:
+ """Read content from a large object."""
+ file_data = io.BytesIO()
+        chunk_size = 8192
+        lobject = None  # descriptor handle; stays None if lo_open is never reached
+
+ async with conn.transaction():
+ try:
+ lo_exists = await conn.fetchval(
+ "SELECT EXISTS(SELECT 1 FROM pg_catalog.pg_largeobject_metadata WHERE oid = $1);",
+ oid,
+ )
+ if not lo_exists:
+ raise R2RException(
+ status_code=404,
+ message=f"Large object {oid} not found.",
+ )
+
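+                # 262144 == 0x40000 == INV_READ: open the large object read-only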
+ lobject = await conn.fetchval(
+ "SELECT lo_open($1, 262144)", oid
+ )
+
+ if lobject is None:
+ raise R2RException(
+ status_code=404,
+ message=f"Failed to open large object {oid}.",
+ )
+
+ while True:
+ chunk = await conn.fetchval(
+ "SELECT loread($1, $2)", lobject, chunk_size
+ )
+ if not chunk:
+ break
+ file_data.write(chunk)
+ except asyncpg.exceptions.UndefinedObjectError:
+ raise R2RException(
+ status_code=404,
+ message=f"Failed to read large object {oid}",
+ ) from None
+            finally:
+                # Close the descriptor only if lo_open actually succeeded
+                if lobject is not None:
+                    await conn.execute("SELECT lo_close($1)", lobject)
+
+ return file_data.getvalue()
+
+ async def delete_file(self, document_id: UUID) -> bool:
+ """Delete a file from storage."""
+ query = f"""
+ SELECT oid FROM {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}
+ WHERE document_id = $1
+ """
+
+ async with self.connection_manager.pool.get_connection() as conn: # type: ignore
+ async with conn.transaction():
+ oid = await conn.fetchval(query, document_id)
+ if not oid:
+ raise R2RException(
+ status_code=404,
+ message=f"File for document {document_id} not found",
+ )
+
+ await self._delete_lobject(conn, oid)
+
+ delete_query = f"""
+ DELETE FROM {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}
+ WHERE document_id = $1
+ """
+ await conn.execute(delete_query, document_id)
+
+ return True
+
+ async def _delete_lobject(self, conn, oid: int) -> None:
+ """Delete a large object."""
+ await conn.execute("SELECT lo_unlink($1)", oid)
+
+ async def get_files_overview(
+ self,
+ offset: int,
+ limit: int,
+ filter_document_ids: Optional[list[UUID]] = None,
+ filter_file_names: Optional[list[str]] = None,
+ ) -> list[dict]:
+ """Get an overview of stored files."""
+ conditions = []
+ params: list[str | list[str] | int] = []
+ query = f"""
+ SELECT document_id, name, oid, size, type, created_at, updated_at
+ FROM {self._get_table_name(PostgresFilesHandler.TABLE_NAME)}
+ """
+
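+        # $N placeholder numbers are computed from len(params) so filters and paging stay aligned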
+ if filter_document_ids:
+ conditions.append(f"document_id = ANY(${len(params) + 1})")
+ params.append([str(doc_id) for doc_id in filter_document_ids])
+
+ if filter_file_names:
+ conditions.append(f"name = ANY(${len(params) + 1})")
+ params.append(filter_file_names)
+
+ if conditions:
+ query += " WHERE " + " AND ".join(conditions)
+
+ query += f" ORDER BY created_at DESC OFFSET ${len(params) + 1} LIMIT ${len(params) + 2}"
+ params.extend([offset, limit])
+
+ results = await self.connection_manager.fetch_query(query, params)
+
+ if not results:
+ raise R2RException(
+ status_code=404,
+ message="No files found with the given filters",
+ )
+
+ return [
+ {
+ "document_id": row["document_id"],
+ "file_name": row["name"],
+ "file_oid": row["oid"],
+ "file_size": row["size"],
+ "file_type": row["type"],
+ "created_at": row["created_at"],
+ "updated_at": row["updated_at"],
+ }
+ for row in results
+ ]
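
A minimal usage sketch (assuming a PostgresFilesHandler whose
connection_manager is already initialized; the `demo` coroutine and the
document UUID are illustrative, not part of the diff above):

    import io
    from uuid import uuid4

    async def demo(handler: PostgresFilesHandler) -> None:
        await handler.create_tables()
        doc_id = uuid4()
        await handler.store_file(
            doc_id, "notes.txt", io.BytesIO(b"hello"), "text/plain"
        )
        name, stream, size = await handler.retrieve_file(doc_id)
        print(name, size, stream.read())
        await handler.delete_file(doc_id)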