Diffstat (limited to '.venv/lib/python3.12/site-packages/shared/utils')
4 files changed, 2832 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/shared/utils/__init__.py b/.venv/lib/python3.12/site-packages/shared/utils/__init__.py new file mode 100644 index 00000000..eb037e22 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/utils/__init__.py @@ -0,0 +1,46 @@ +from .base_utils import ( + _decorate_vector_type, + _get_vector_column_str, + decrement_version, + deep_update, + dump_collector, + dump_obj, + format_search_results_for_llm, + generate_default_prompt_id, + generate_default_user_collection_id, + generate_document_id, + generate_entity_document_id, + generate_extraction_id, + generate_id, + generate_user_id, + increment_version, + validate_uuid, + yield_sse_event, +) +from .splitter.text import RecursiveCharacterTextSplitter, TextSplitter + +__all__ = [ + "format_search_results_for_llm", + # ID generation + "generate_id", + "generate_document_id", + "generate_extraction_id", + "generate_default_user_collection_id", + "generate_user_id", + "generate_default_prompt_id", + "generate_entity_document_id", + # Other + "increment_version", + "decrement_version", + "validate_uuid", + "deep_update", + # Text splitter + "RecursiveCharacterTextSplitter", + "TextSplitter", + # Vector utils + "_decorate_vector_type", + "_get_vector_column_str", + "yield_sse_event", + "dump_collector", + "dump_obj", +] diff --git a/.venv/lib/python3.12/site-packages/shared/utils/base_utils.py b/.venv/lib/python3.12/site-packages/shared/utils/base_utils.py new file mode 100644 index 00000000..1864d0b4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/utils/base_utils.py @@ -0,0 +1,783 @@ +import json +import logging +import math +import uuid +from abc import ABCMeta +from copy import deepcopy +from datetime import datetime +from typing import TYPE_CHECKING, Any, Optional, Tuple, TypeVar +from uuid import NAMESPACE_DNS, UUID, uuid4, uuid5 + +import tiktoken + +from ..abstractions import ( + AggregateSearchResult, + AsyncSyncMeta, + GraphCommunityResult, + GraphEntityResult, + GraphRelationshipResult, +) +from ..abstractions.vector import VectorQuantizationType + +if TYPE_CHECKING: + pass + + +logger = logging.getLogger() + + +def id_to_shorthand(id: str | UUID): + return str(id)[:7] + + +def format_search_results_for_llm( + results: AggregateSearchResult, + collector: Any, # SearchResultsCollector +) -> str: + """ + Instead of resetting 'source_counter' to 1, we: + - For each chunk / graph / web / doc in `results`, + - Find the aggregator index from the collector, + - Print 'Source [X]:' with that aggregator index. + """ + lines = [] + + # We'll build a quick helper to locate aggregator indices for each object: + # Or you can rely on the fact that we've added them to the collector + # in the same order. But let's do a "lookup aggregator index" approach: + + # 1) Chunk search + if results.chunk_search_results: + lines.append("Vector Search Results:") + for c in results.chunk_search_results: + lines.append(f"Source ID [{id_to_shorthand(c.id)}]:") + lines.append(c.text or "") # or c.text[:200] to truncate + + # 2) Graph search + if results.graph_search_results: + lines.append("Graph Search Results:") + for g in results.graph_search_results: + lines.append(f"Source ID [{id_to_shorthand(g.id)}]:") + if isinstance(g.content, GraphCommunityResult): + lines.append(f"Community Name: {g.content.name}") + lines.append(f"ID: {g.content.id}") + lines.append(f"Summary: {g.content.summary}") + # etc. ... 
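(Illustrative sketch, not part of the diff: the __init__.py above re-exports the ID helpers defined further down in base_utils.py. Assuming the package imports as shared.utils, matching its site-packages path, the uuid5-based helpers return the same UUID for the same inputs.)

    from uuid import UUID
    from shared.utils import generate_document_id, generate_id, increment_version

    user_id = UUID("9f5c7b1e-2a3d-4e6f-8a90-123456789abc")   # hypothetical user UUID
    doc_id = generate_document_id("reports/q3.pdf", user_id)
    assert doc_id == generate_document_id("reports/q3.pdf", user_id)  # uuid5 => repeatable

    run_id = generate_id()            # no label given => random (seeded from uuid4)
    print(increment_version("v1"))    # "v2" -- only the trailing character is incremented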
+ elif isinstance(g.content, GraphEntityResult): + lines.append(f"Entity Name: {g.content.name}") + lines.append(f"Description: {g.content.description}") + elif isinstance(g.content, GraphRelationshipResult): + lines.append( + f"Relationship: {g.content.subject}-{g.content.predicate}-{g.content.object}" + ) + # Add metadata if needed + + # 3) Web search + if results.web_search_results: + lines.append("Web Search Results:") + for w in results.web_search_results: + lines.append(f"Source ID [{id_to_shorthand(w.id)}]:") + lines.append(f"Title: {w.title}") + lines.append(f"Link: {w.link}") + lines.append(f"Snippet: {w.snippet}") + + # 4) Local context docs + if results.document_search_results: + lines.append("Local Context Documents:") + for doc_result in results.document_search_results: + doc_title = doc_result.title or "Untitled Document" + doc_id = doc_result.id + summary = doc_result.summary + + lines.append(f"Full Document ID: {doc_id}") + lines.append(f"Shortened Document ID: {id_to_shorthand(doc_id)}") + lines.append(f"Document Title: {doc_title}") + if summary: + lines.append(f"Summary: {summary}") + + if doc_result.chunks: + # Then each chunk inside: + for chunk in doc_result.chunks: + lines.append( + f"\nChunk ID {id_to_shorthand(chunk['id'])}:\n{chunk['text']}" + ) + + result = "\n".join(lines) + return result + + +def _generate_id_from_label(label) -> UUID: + return uuid5(NAMESPACE_DNS, label) + + +def generate_id(label: Optional[str] = None) -> UUID: + """Generates a unique run id.""" + return _generate_id_from_label( + label if label is not None else str(uuid4()) + ) + + +def generate_document_id(filename: str, user_id: UUID) -> UUID: + """Generates a unique document id from a given filename and user id.""" + safe_filename = filename.replace("/", "_") + return _generate_id_from_label(f"{safe_filename}-{str(user_id)}") + + +def generate_extraction_id( + document_id: UUID, iteration: int = 0, version: str = "0" +) -> UUID: + """Generates a unique extraction id from a given document id and + iteration.""" + return _generate_id_from_label(f"{str(document_id)}-{iteration}-{version}") + + +def generate_default_user_collection_id(user_id: UUID) -> UUID: + """Generates a unique collection id from a given user id.""" + return _generate_id_from_label(str(user_id)) + + +def generate_user_id(email: str) -> UUID: + """Generates a unique user id from a given email.""" + return _generate_id_from_label(email) + + +def generate_default_prompt_id(prompt_name: str) -> UUID: + """Generates a unique prompt id.""" + return _generate_id_from_label(prompt_name) + + +def generate_entity_document_id() -> UUID: + """Generates a unique document id inserting entities into a graph.""" + generation_time = datetime.now().isoformat() + return _generate_id_from_label(f"entity-{generation_time}") + + +def increment_version(version: str) -> str: + prefix = version[:-1] + suffix = int(version[-1]) + return f"{prefix}{suffix + 1}" + + +def decrement_version(version: str) -> str: + prefix = version[:-1] + suffix = int(version[-1]) + return f"{prefix}{max(0, suffix - 1)}" + + +def validate_uuid(uuid_str: str) -> UUID: + return UUID(uuid_str) + + +def update_settings_from_dict(server_settings, settings_dict: dict): + """Updates a settings object with values from a dictionary.""" + settings = deepcopy(server_settings) + for key, value in settings_dict.items(): + if value is not None: + if isinstance(value, dict): + for k, v in value.items(): + if isinstance(getattr(settings, key), dict): + getattr(settings, key)[k] = v + 
else: + setattr(getattr(settings, key), k, v) + else: + setattr(settings, key, value) + + return settings + + +def _decorate_vector_type( + input_str: str, + quantization_type: VectorQuantizationType = VectorQuantizationType.FP32, +) -> str: + return f"{quantization_type.db_type}{input_str}" + + +def _get_vector_column_str( + dimension: int | float, quantization_type: VectorQuantizationType +) -> str: + """Returns a string representation of a vector column type. + + Explicitly handles the case where the dimension is not a valid number meant + to support embedding models that do not allow for specifying the dimension. + """ + if math.isnan(dimension) or dimension <= 0: + vector_dim = "" # Allows for Postgres to handle any dimension + else: + vector_dim = f"({dimension})" + return _decorate_vector_type(vector_dim, quantization_type) + + +KeyType = TypeVar("KeyType") + + +def deep_update( + mapping: dict[KeyType, Any], *updating_mappings: dict[KeyType, Any] +) -> dict[KeyType, Any]: + """ + Taken from Pydantic v1: + https://github.com/pydantic/pydantic/blob/fd2991fe6a73819b48c906e3c3274e8e47d0f761/pydantic/utils.py#L200 + """ + updated_mapping = mapping.copy() + for updating_mapping in updating_mappings: + for k, v in updating_mapping.items(): + if ( + k in updated_mapping + and isinstance(updated_mapping[k], dict) + and isinstance(v, dict) + ): + updated_mapping[k] = deep_update(updated_mapping[k], v) + else: + updated_mapping[k] = v + return updated_mapping + + +def tokens_count_for_message(message, encoding): + """Return the number of tokens used by a single message.""" + tokens_per_message = 3 + + num_tokens = 0 + num_tokens += tokens_per_message + if message.get("function_call"): + num_tokens += len(encoding.encode(message["function_call"]["name"])) + num_tokens += len( + encoding.encode(message["function_call"]["arguments"]) + ) + elif message.get("tool_calls"): + for tool_call in message["tool_calls"]: + num_tokens += len(encoding.encode(tool_call["function"]["name"])) + num_tokens += len( + encoding.encode(tool_call["function"]["arguments"]) + ) + else: + if "content" in message: + num_tokens += len(encoding.encode(message["content"])) + + return num_tokens + + +def num_tokens_from_messages(messages, model="gpt-4o"): + """Return the number of tokens used by a list of messages for both user and assistant.""" + try: + encoding = tiktoken.encoding_for_model(model) + except KeyError: + logger.warning("Warning: model not found. Using cl100k_base encoding.") + encoding = tiktoken.get_encoding("cl100k_base") + + tokens = 0 + for message_ in messages: + tokens += tokens_count_for_message(message_, encoding) + + tokens += 3 # every reply is primed with assistant + return tokens + + +class SearchResultsCollector: + """ + Collects search results in the form (source_type, result_obj). + Handles both object-oriented and dictionary-based search results. 
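(Illustrative sketch, not part of the diff: deep_update, defined above, merges nested dictionaries key by key instead of replacing whole sub-dicts the way dict.update would; the import path is assumed from the site-packages layout.)

    from shared.utils import deep_update

    base = {"llm": {"model": "gpt-4o", "temperature": 0.2}, "limit": 10}
    merged = deep_update(base, {"llm": {"temperature": 0.7}})
    print(merged)   # {'llm': {'model': 'gpt-4o', 'temperature': 0.7}, 'limit': 10}
    print(base)     # inputs are not mutated; a new merged mapping is returned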
+ """ + + def __init__(self): + # We'll store a list of (source_type, result_obj) + self._results_in_order = [] + + @property + def results(self): + """Get the results list""" + return self._results_in_order + + @results.setter + def results(self, value): + """ + Set the results directly, with automatic type detection for 'unknown' items + Handles the format: [('unknown', {...}), ('unknown', {...})] + """ + self._results_in_order = [] + + if isinstance(value, list): + for item in value: + if isinstance(item, tuple) and len(item) == 2: + source_type, result_obj = item + + # Only auto-detect if the source type is "unknown" + if source_type == "unknown": + detected_type = self._detect_result_type(result_obj) + self._results_in_order.append( + (detected_type, result_obj) + ) + else: + self._results_in_order.append( + (source_type, result_obj) + ) + else: + # If not a tuple, detect and add + detected_type = self._detect_result_type(item) + self._results_in_order.append((detected_type, item)) + else: + raise ValueError("Results must be a list") + + def add_aggregate_result(self, agg): + """ + Flatten the chunk_search_results, graph_search_results, web_search_results, + and document_search_results into the collector, including nested chunks. + """ + if hasattr(agg, "chunk_search_results") and agg.chunk_search_results: + for c in agg.chunk_search_results: + self._results_in_order.append(("chunk", c)) + + if hasattr(agg, "graph_search_results") and agg.graph_search_results: + for g in agg.graph_search_results: + self._results_in_order.append(("graph", g)) + + if hasattr(agg, "web_search_results") and agg.web_search_results: + for w in agg.web_search_results: + self._results_in_order.append(("web", w)) + + # Add documents and extract their chunks + if ( + hasattr(agg, "document_search_results") + and agg.document_search_results + ): + for doc in agg.document_search_results: + # Add the document itself + self._results_in_order.append(("doc", doc)) + + # Extract and add chunks from the document + chunks = None + if isinstance(doc, dict): + chunks = doc.get("chunks", []) + elif hasattr(doc, "chunks") and doc.chunks is not None: + chunks = doc.chunks + + if chunks: + for chunk in chunks: + # Ensure each chunk has the minimum required attributes + if isinstance(chunk, dict) and "id" in chunk: + # Add the chunk directly to results for citation lookup + self._results_in_order.append(("chunk", chunk)) + elif hasattr(chunk, "id"): + self._results_in_order.append(("chunk", chunk)) + + def add_result(self, result_obj, source_type=None): + """ + Add a single result object to the collector. + If source_type is not provided, automatically detect the type. + """ + if source_type: + self._results_in_order.append((source_type, result_obj)) + return source_type + + detected_type = self._detect_result_type(result_obj) + self._results_in_order.append((detected_type, result_obj)) + return detected_type + + def _detect_result_type(self, obj): + """ + Detect the type of a result object based on its properties. + Works with both object attributes and dictionary keys. 
+ """ + # Handle dictionary types first (common for web search results) + if isinstance(obj, dict): + # Web search pattern + if all(k in obj for k in ["title", "link"]) and any( + k in obj for k in ["snippet", "description"] + ): + return "web" + + # Check for graph dictionary patterns + if "content" in obj and isinstance(obj["content"], dict): + content = obj["content"] + if all(k in content for k in ["name", "description"]): + return "graph" # Entity + if all( + k in content for k in ["subject", "predicate", "object"] + ): + return "graph" # Relationship + if all(k in content for k in ["name", "summary"]): + return "graph" # Community + + # Chunk pattern + if all(k in obj for k in ["text", "id"]) and any( + k in obj for k in ["score", "metadata"] + ): + return "chunk" + + # Context document pattern + if "document" in obj and "chunks" in obj: + return "doc" + + # Check for explicit type indicator + if "type" in obj: + type_val = str(obj["type"]).lower() + if any(t in type_val for t in ["web", "organic"]): + return "web" + if "graph" in type_val: + return "graph" + if "chunk" in type_val: + return "chunk" + if "document" in type_val: + return "doc" + + # Handle object attributes for OOP-style results + if hasattr(obj, "result_type"): + result_type = str(obj.result_type).lower() + if result_type in ["entity", "relationship", "community"]: + return "graph" + + # Check class name hints + class_name = obj.__class__.__name__ + if "Graph" in class_name: + return "graph" + if "Chunk" in class_name: + return "chunk" + if "Web" in class_name: + return "web" + if "Document" in class_name: + return "doc" + + # Check for object attribute patterns + if hasattr(obj, "content"): + content = obj.content + if hasattr(content, "name") and hasattr(content, "description"): + return "graph" # Entity + if hasattr(content, "subject") and hasattr(content, "predicate"): + return "graph" # Relationship + if hasattr(content, "name") and hasattr(content, "summary"): + return "graph" # Community + + if ( + hasattr(obj, "text") + and hasattr(obj, "id") + and (hasattr(obj, "score") or hasattr(obj, "metadata")) + ): + return "chunk" + + if ( + hasattr(obj, "title") + and hasattr(obj, "link") + and hasattr(obj, "snippet") + ): + return "web" + + if hasattr(obj, "document") and hasattr(obj, "chunks"): + return "doc" + + # Default when type can't be determined + return "unknown" + + def find_by_short_id(self, short_id): + """Find a result by its short ID prefix with better chunk handling""" + if not short_id: + return None + + # First try direct lookup using regular iteration + for _, result_obj in self._results_in_order: + # Check dictionary objects + if isinstance(result_obj, dict) and "id" in result_obj: + result_id = str(result_obj["id"]) + if result_id.startswith(short_id): + return result_obj + + # Check object with id attribute + elif hasattr(result_obj, "id"): + obj_id = getattr(result_obj, "id", None) + if obj_id and str(obj_id).startswith(short_id): + # Convert to dict if possible + if hasattr(result_obj, "as_dict"): + return result_obj.as_dict() + elif hasattr(result_obj, "model_dump"): + return result_obj.model_dump() + elif hasattr(result_obj, "dict"): + return result_obj.dict() + else: + return result_obj + + # If not found, look for chunks inside documents that weren't extracted properly + for source_type, result_obj in self._results_in_order: + if source_type == "doc": + # Try various ways to access chunks + chunks = None + if isinstance(result_obj, dict) and "chunks" in result_obj: + chunks = 
result_obj["chunks"] + elif ( + hasattr(result_obj, "chunks") + and result_obj.chunks is not None + ): + chunks = result_obj.chunks + + if chunks: + for chunk in chunks: + # Try each chunk + chunk_id = None + if isinstance(chunk, dict) and "id" in chunk: + chunk_id = chunk["id"] + elif hasattr(chunk, "id"): + chunk_id = chunk.id + + if chunk_id and str(chunk_id).startswith(short_id): + return chunk + + return None + + def get_results_by_type(self, type_name): + """Get all results of a specific type""" + return [ + result_obj + for source_type, result_obj in self._results_in_order + if source_type == type_name + ] + + def __repr__(self): + """String representation showing counts by type""" + type_counts = {} + for source_type, _ in self._results_in_order: + type_counts[source_type] = type_counts.get(source_type, 0) + 1 + + return f"SearchResultsCollector with {len(self._results_in_order)} results: {type_counts}" + + def get_all_results(self) -> list[Tuple[str, Any]]: + """ + Return list of (source_type, result_obj, aggregator_index), + in the order appended. + """ + return self._results_in_order + + +def convert_nonserializable_objects(obj): + if hasattr(obj, "model_dump"): + obj = obj.model_dump() + if hasattr(obj, "as_dict"): + obj = obj.as_dict() + if hasattr(obj, "to_dict"): + obj = obj.to_dict() + + if isinstance(obj, dict): + new_obj = {} + for key, value in obj.items(): + # Convert key to string if it is a UUID or not already a string. + new_key = str(key) if not isinstance(key, str) else key + new_obj[new_key] = convert_nonserializable_objects(value) + return new_obj + elif isinstance(obj, list): + return [convert_nonserializable_objects(item) for item in obj] + elif isinstance(obj, tuple): + return tuple(convert_nonserializable_objects(item) for item in obj) + elif isinstance(obj, set): + return {convert_nonserializable_objects(item) for item in obj} + elif isinstance(obj, uuid.UUID): + return str(obj) + elif isinstance(obj, datetime): + return obj.isoformat() # Convert datetime to ISO formatted string + else: + return obj + + +def dump_obj(obj) -> list[dict[str, Any]]: + if hasattr(obj, "model_dump"): + obj = obj.model_dump() + elif hasattr(obj, "dict"): + obj = obj.dict() + elif hasattr(obj, "as_dict"): + obj = obj.as_dict() + elif hasattr(obj, "to_dict"): + obj = obj.to_dict() + obj = convert_nonserializable_objects(obj) + + return obj + + +def dump_collector(collector: SearchResultsCollector) -> list[dict[str, Any]]: + dumped = [] + for source_type, result_obj in collector.get_all_results(): + # Get the dictionary from the result object + if hasattr(result_obj, "model_dump"): + result_dict = result_obj.model_dump() + elif hasattr(result_obj, "dict"): + result_dict = result_obj.dict() + elif hasattr(result_obj, "as_dict"): + result_dict = result_obj.as_dict() + elif hasattr(result_obj, "to_dict"): + result_dict = result_obj.to_dict() + else: + result_dict = ( + result_obj # Fallback if no conversion method is available + ) + + # Use the recursive conversion on the entire dictionary + result_dict = convert_nonserializable_objects(result_dict) + + dumped.append( + { + "source_type": source_type, + "result": result_dict, + } + ) + return dumped + + +def num_tokens(text, model="gpt-4o"): + try: + encoding = tiktoken.encoding_for_model(model) + except KeyError: + encoding = tiktoken.get_encoding("cl100k_base") + + """Return the number of tokens used by a list of messages for both user and assistant.""" + return len(encoding.encode(text, disallowed_special=())) + + +class 
CombinedMeta(AsyncSyncMeta, ABCMeta): + pass + + +async def yield_sse_event(event_name: str, payload: dict, chunk_size=1024): + """ + Helper that yields a single SSE event in properly chunked lines. + + e.g. event: event_name + data: (partial JSON 1) + data: (partial JSON 2) + ... + [blank line to end event] + """ + + # SSE: first the "event: ..." + yield f"event: {event_name}\n" + + # Convert payload to JSON + content_str = json.dumps(payload, default=str) + + # data + yield f"data: {content_str}\n" + + # blank line signals end of SSE event + yield "\n" + + +class SSEFormatter: + """ + Enhanced formatter for Server-Sent Events (SSE) with citation tracking. + Extends the existing SSEFormatter with improved citation handling. + """ + + @staticmethod + async def yield_citation_event( + citation_data: dict, + ): + """ + Emits a citation event with optimized payload. + + Args: + citation_id: The short ID of the citation (e.g., 'abc1234') + span: (start, end) position tuple for this occurrence + payload: Source object (included only for first occurrence) + is_new: Whether this is the first time we've seen this citation + citation_id_counter: Optional counter for citation occurrences + + Yields: + Formatted SSE event lines + """ + + # Include the full payload only for new citations + if not citation_data.get("is_new") or "payload" not in citation_data: + citation_data["payload"] = None + + # Yield the event + async for line in yield_sse_event("citation", citation_data): + yield line + + @staticmethod + async def yield_final_answer_event( + final_data: dict, + ): + # Yield the event + async for line in yield_sse_event("final_answer", final_data): + yield line + + # Include other existing SSEFormatter methods for compatibility + @staticmethod + async def yield_message_event(text_segment, msg_id=None): + msg_id = msg_id or f"msg_{uuid.uuid4().hex[:8]}" + msg_payload = { + "id": msg_id, + "object": "agent.message.delta", + "delta": { + "content": [ + { + "type": "text", + "payload": { + "value": text_segment, + "annotations": [], + }, + } + ] + }, + } + async for line in yield_sse_event("message", msg_payload): + yield line + + @staticmethod + async def yield_thinking_event(text_segment, thinking_id=None): + thinking_id = thinking_id or f"think_{uuid.uuid4().hex[:8]}" + thinking_data = { + "id": thinking_id, + "object": "agent.thinking.delta", + "delta": { + "content": [ + { + "type": "text", + "payload": { + "value": text_segment, + "annotations": [], + }, + } + ] + }, + } + async for line in yield_sse_event("thinking", thinking_data): + yield line + + @staticmethod + def yield_done_event(): + return "event: done\ndata: [DONE]\n\n" + + @staticmethod + async def yield_error_event(error_message, error_id=None): + error_id = error_id or f"err_{uuid.uuid4().hex[:8]}" + error_payload = { + "id": error_id, + "object": "agent.error", + "error": {"message": error_message, "type": "agent_error"}, + } + async for line in yield_sse_event("error", error_payload): + yield line + + @staticmethod + async def yield_tool_call_event(tool_call_data): + from ..api.models.retrieval.responses import ToolCallEvent + + tc_event = ToolCallEvent(event="tool_call", data=tool_call_data) + async for line in yield_sse_event( + "tool_call", tc_event.dict()["data"] + ): + yield line + + # New helper for emitting search results: + @staticmethod + async def yield_search_results_event(aggregated_results): + payload = { + "id": "search_1", + "object": "rag.search_results", + "data": aggregated_results.as_dict(), + } + async for line 
in yield_sse_event("search_results", payload): + yield line + + @staticmethod + async def yield_tool_result_event(tool_result_data): + from ..api.models.retrieval.responses import ToolResultEvent + + tr_event = ToolResultEvent(event="tool_result", data=tool_result_data) + async for line in yield_sse_event( + "tool_result", tr_event.dict()["data"] + ): + yield line diff --git a/.venv/lib/python3.12/site-packages/shared/utils/splitter/__init__.py b/.venv/lib/python3.12/site-packages/shared/utils/splitter/__init__.py new file mode 100644 index 00000000..07a9f554 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/utils/splitter/__init__.py @@ -0,0 +1,3 @@ +from .text import RecursiveCharacterTextSplitter + +__all__ = ["RecursiveCharacterTextSplitter"] diff --git a/.venv/lib/python3.12/site-packages/shared/utils/splitter/text.py b/.venv/lib/python3.12/site-packages/shared/utils/splitter/text.py new file mode 100644 index 00000000..92a7c81b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/shared/utils/splitter/text.py @@ -0,0 +1,2000 @@ +# Source - LangChain +# URL: https://github.com/langchain-ai/langchain/blob/6a5b084704afa22ca02f78d0464f35aed75d1ff2/libs/langchain/langchain/text_splitter.py#L851 +"""**Text Splitters** are classes for splitting text. + +**Class hierarchy:** + +.. code-block:: + + BaseDocumentTransformer --> TextSplitter --> <name>TextSplitter # Example: CharacterTextSplitter + RecursiveCharacterTextSplitter --> <name>TextSplitter + +Note: **MarkdownHeaderTextSplitter** and **HTMLHeaderTextSplitter do not derive from TextSplitter. + + +**Main helpers:** + +.. code-block:: + + Document, Tokenizer, Language, LineType, HeaderType +""" # noqa: E501 + +from __future__ import annotations + +import copy +import json +import logging +import pathlib +import re +from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import Enum +from io import BytesIO, StringIO +from typing import ( + AbstractSet, + Any, + Callable, + Collection, + Iterable, + Literal, + Optional, + Sequence, + Tuple, + Type, + TypedDict, + TypeVar, + cast, +) + +import requests +from pydantic import BaseModel, Field, PrivateAttr +from typing_extensions import NotRequired + +logger = logging.getLogger() + +TS = TypeVar("TS", bound="TextSplitter") + + +class BaseSerialized(TypedDict): + """Base class for serialized objects.""" + + lc: int + id: list[str] + name: NotRequired[str] + graph: NotRequired[dict[str, Any]] + + +class SerializedConstructor(BaseSerialized): + """Serialized constructor.""" + + type: Literal["constructor"] + kwargs: dict[str, Any] + + +class SerializedSecret(BaseSerialized): + """Serialized secret.""" + + type: Literal["secret"] + + +class SerializedNotImplemented(BaseSerialized): + """Serialized not implemented.""" + + type: Literal["not_implemented"] + repr: Optional[str] + + +def try_neq_default(value: Any, key: str, model: BaseModel) -> bool: + """Try to determine if a value is different from the default. + + Args: + value: The value. + key: The key. + model: The model. + + Returns: + Whether the value is different from the default. + """ + try: + return model.__fields__[key].get_default() != value + except Exception: + return True + + +class Serializable(BaseModel, ABC): + """Serializable base class.""" + + @classmethod + def is_lc_serializable(cls) -> bool: + """Is this class serializable?""" + return False + + @classmethod + def get_lc_namespace(cls) -> list[str]: + """Get the namespace of the langchain object. 
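(Stepping back to base_utils.py for a moment -- illustrative sketch, not part of the diff: each SSEFormatter helper is an async generator over ready-to-send SSE lines built by yield_sse_event.)

    import asyncio
    from shared.utils.base_utils import SSEFormatter

    async def demo():
        async for line in SSEFormatter.yield_message_event("Hello"):
            print(repr(line))        # 'event: message\n', 'data: {...}\n', '\n'
        print(repr(SSEFormatter.yield_done_event()))   # plain string, not an async generator

    asyncio.run(demo())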
+ + For example, if the class is `langchain.llms.openai.OpenAI`, then the + namespace is ["langchain", "llms", "openai"] + """ + return cls.__module__.split(".") + + @property + def lc_secrets(self) -> dict[str, str]: + """A map of constructor argument names to secret ids. + + For example, {"openai_api_key": "OPENAI_API_KEY"} + """ + return {} + + @property + def lc_attributes(self) -> dict: + """List of attribute names that should be included in the serialized + kwargs. + + These attributes must be accepted by the constructor. + """ + return {} + + @classmethod + def lc_id(cls) -> list[str]: + """A unique identifier for this class for serialization purposes. + + The unique identifier is a list of strings that describes the path to + the object. + """ + return [*cls.get_lc_namespace(), cls.__name__] + + class Config: + extra = "ignore" + + def __repr_args__(self) -> Any: + return [ + (k, v) + for k, v in super().__repr_args__() + if (k not in self.__fields__ or try_neq_default(v, k, self)) + ] + + _lc_kwargs: dict[str, Any] = PrivateAttr(default_factory=dict) + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._lc_kwargs = kwargs + + def to_json( + self, + ) -> SerializedConstructor | SerializedNotImplemented: + if not self.is_lc_serializable(): + return self.to_json_not_implemented() + + secrets = dict() + # Get latest values for kwargs if there is an attribute with same name + lc_kwargs = { + k: getattr(self, k, v) + for k, v in self._lc_kwargs.items() + if not (self.__exclude_fields__ or {}).get(k, False) # type: ignore + } + + # Merge the lc_secrets and lc_attributes from every class in the MRO + for cls in [None, *self.__class__.mro()]: + # Once we get to Serializable, we're done + if cls is Serializable: + break + + if cls: + deprecated_attributes = [ + "lc_namespace", + "lc_serializable", + ] + + for attr in deprecated_attributes: + if hasattr(cls, attr): + raise ValueError( + f"Class {self.__class__} has a deprecated " + f"attribute {attr}. Please use the corresponding " + f"classmethod instead." + ) + + # Get a reference to self bound to each class in the MRO + this = cast( + Serializable, self if cls is None else super(cls, self) + ) + + secrets.update(this.lc_secrets) + # Now also add the aliases for the secrets + # This ensures known secret aliases are hidden. + # Note: this does NOT hide any other extra kwargs + # that are not present in the fields. 
+ for key in list(secrets): + value = secrets[key] + if key in this.__fields__: + secrets[this.__fields__[key].alias] = value # type: ignore + lc_kwargs.update(this.lc_attributes) + + # include all secrets, even if not specified in kwargs + # as these secrets may be passed as an environment variable instead + for key in secrets.keys(): + secret_value = getattr(self, key, None) or lc_kwargs.get(key) + if secret_value is not None: + lc_kwargs.update({key: secret_value}) + + return { + "lc": 1, + "type": "constructor", + "id": self.lc_id(), + "kwargs": ( + lc_kwargs + if not secrets + else _replace_secrets(lc_kwargs, secrets) + ), + } + + def to_json_not_implemented(self) -> SerializedNotImplemented: + return to_json_not_implemented(self) + + +def _replace_secrets( + root: dict[Any, Any], secrets_map: dict[str, str] +) -> dict[Any, Any]: + result = root.copy() + for path, secret_id in secrets_map.items(): + [*parts, last] = path.split(".") + current = result + for part in parts: + if part not in current: + break + current[part] = current[part].copy() + current = current[part] + if last in current: + current[last] = { + "lc": 1, + "type": "secret", + "id": [secret_id], + } + return result + + +def to_json_not_implemented(obj: object) -> SerializedNotImplemented: + """Serialize a "not implemented" object. + + Args: + obj: object to serialize + + Returns: + SerializedNotImplemented + """ + _id: list[str] = [] + try: + if hasattr(obj, "__name__"): + _id = [*obj.__module__.split("."), obj.__name__] + elif hasattr(obj, "__class__"): + _id = [ + *obj.__class__.__module__.split("."), + obj.__class__.__name__, + ] + except Exception: + pass + + result: SerializedNotImplemented = { + "lc": 1, + "type": "not_implemented", + "id": _id, + "repr": None, + } + try: + result["repr"] = repr(obj) + except Exception: + pass + return result + + +class SplitterDocument(Serializable): + """Class for storing a piece of text and associated metadata.""" + + page_content: str + """String text.""" + metadata: dict = Field(default_factory=dict) + """Arbitrary metadata about the page content (e.g., source, relationships + to other documents, etc.).""" + type: Literal["Document"] = "Document" + + def __init__(self, page_content: str, **kwargs: Any) -> None: + """Pass page_content in as positional or named arg.""" + super().__init__(page_content=page_content, **kwargs) + + @classmethod + def is_lc_serializable(cls) -> bool: + """Return whether this class is serializable.""" + return True + + @classmethod + def get_lc_namespace(cls) -> list[str]: + """Get the namespace of the langchain object.""" + return ["langchain", "schema", "document"] + + +class BaseDocumentTransformer(ABC): + """Abstract base class for document transformation systems. + + A document transformation system takes a sequence of Documents and returns a + sequence of transformed Documents. + + Example: + .. 
code-block:: python + + class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel): + embeddings: Embeddings + similarity_fn: Callable = cosine_similarity + similarity_threshold: float = 0.95 + + class Config: + arbitrary_types_allowed = True + + def transform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + stateful_documents = get_stateful_documents(documents) + embedded_documents = _get_embeddings_from_stateful_docs( + self.embeddings, stateful_documents + ) + included_idxs = _filter_similar_embeddings( + embedded_documents, self.similarity_fn, self.similarity_threshold + ) + return [stateful_documents[i] for i in sorted(included_idxs)] + + async def atransform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + raise NotImplementedError + """ # noqa: E501 + + @abstractmethod + def transform_documents( + self, documents: Sequence[SplitterDocument], **kwargs: Any + ) -> Sequence[SplitterDocument]: + """Transform a list of documents. + + Args: + documents: A sequence of Documents to be transformed. + + Returns: + A list of transformed Documents. + """ + + async def atransform_documents( + self, documents: Sequence[SplitterDocument], **kwargs: Any + ) -> Sequence[SplitterDocument]: + """Asynchronously transform a list of documents. + + Args: + documents: A sequence of Documents to be transformed. + + Returns: + A list of transformed Documents. + """ + raise NotImplementedError("This method is not implemented.") + # return await langchain_core.runnables.config.run_in_executor( + # None, self.transform_documents, documents, **kwargs + # ) + + +def _make_spacy_pipe_for_splitting( + pipe: str, *, max_length: int = 1_000_000 +) -> Any: # avoid importing spacy + try: + import spacy + except ImportError: + raise ImportError( + "Spacy is not installed, run `pip install spacy`." + ) from None + if pipe == "sentencizer": + from spacy.lang.en import English + + sentencizer = English() + sentencizer.add_pipe("sentencizer") + else: + sentencizer = spacy.load(pipe, exclude=["ner", "tagger"]) + sentencizer.max_length = max_length + return sentencizer + + +def _split_text_with_regex( + text: str, separator: str, keep_separator: bool +) -> list[str]: + # Now that we have the separator, split the text + if separator: + if keep_separator: + # The parentheses in the pattern keep the delimiters in the result. + _splits = re.split(f"({separator})", text) + splits = [ + _splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2) + ] + if len(_splits) % 2 == 0: + splits += _splits[-1:] + splits = [_splits[0]] + splits + else: + splits = re.split(separator, text) + else: + splits = list(text) + return [s for s in splits if s != ""] + + +class TextSplitter(BaseDocumentTransformer, ABC): + """Interface for splitting text into chunks.""" + + def __init__( + self, + chunk_size: int = 4000, + chunk_overlap: int = 200, + length_function: Callable[[str], int] = len, + keep_separator: bool = False, + add_start_index: bool = False, + strip_whitespace: bool = True, + ) -> None: + """Create a new TextSplitter. 
+ + Args: + chunk_size: Maximum size of chunks to return + chunk_overlap: Overlap in characters between chunks + length_function: Function that measures the length of given chunks + keep_separator: Whether to keep the separator in the chunks + add_start_index: If `True`, includes chunk's start index in + metadata + strip_whitespace: If `True`, strips whitespace from the start and + end of every document + """ + if chunk_overlap > chunk_size: + raise ValueError( + f"Got a larger chunk overlap ({chunk_overlap}) than chunk size " + f"({chunk_size}), should be smaller." + ) + self._chunk_size = chunk_size + self._chunk_overlap = chunk_overlap + self._length_function = length_function + self._keep_separator = keep_separator + self._add_start_index = add_start_index + self._strip_whitespace = strip_whitespace + + @abstractmethod + def split_text(self, text: str) -> list[str]: + """Split text into multiple components.""" + + def create_documents( + self, texts: list[str], metadatas: Optional[list[dict]] = None + ) -> list[SplitterDocument]: + """Create documents from a list of texts.""" + _metadatas = metadatas or [{}] * len(texts) + documents = [] + for i, text in enumerate(texts): + index = 0 + previous_chunk_len = 0 + for chunk in self.split_text(text): + metadata = copy.deepcopy(_metadatas[i]) + if self._add_start_index: + offset = index + previous_chunk_len - self._chunk_overlap + index = text.find(chunk, max(0, offset)) + metadata["start_index"] = index + previous_chunk_len = len(chunk) + new_doc = SplitterDocument( + page_content=chunk, metadata=metadata + ) + documents.append(new_doc) + return documents + + def split_documents( + self, documents: Iterable[SplitterDocument] + ) -> list[SplitterDocument]: + """Split documents.""" + texts, metadatas = [], [] + for doc in documents: + texts.append(doc.page_content) + metadatas.append(doc.metadata) + return self.create_documents(texts, metadatas=metadatas) + + def _join_docs(self, docs: list[str], separator: str) -> Optional[str]: + text = separator.join(docs) + if self._strip_whitespace: + text = text.strip() + if text == "": + return None + else: + return text + + def _merge_splits( + self, splits: Iterable[str], separator: str + ) -> list[str]: + # We now want to combine these smaller pieces into medium size + # chunks to send to the LLM. 
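(Illustrative sketch, not part of the diff: TextSplitter is abstract -- only split_text must be supplied -- and create_documents then wraps the chunks in SplitterDocument objects with per-text metadata. A toy subclass, assuming nothing beyond the base class shown above:)

    from shared.utils.splitter.text import TextSplitter

    class SentenceSplitter(TextSplitter):
        # toy splitter: break on ". " and let _merge_splits pack pieces up to chunk_size
        def split_text(self, text: str) -> list[str]:
            return self._merge_splits(text.split(". "), ". ")

    splitter = SentenceSplitter(chunk_size=80, chunk_overlap=0)
    docs = splitter.create_documents(
        ["First sentence. Second sentence. Third sentence."],
        metadatas=[{"source": "example.txt"}],
    )
    for d in docs:
        print(d.metadata, d.page_content)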
+ separator_len = self._length_function(separator) + + docs = [] + current_doc: list[str] = [] + total = 0 + for d in splits: + _len = self._length_function(d) + if ( + total + _len + (separator_len if len(current_doc) > 0 else 0) + > self._chunk_size + ): + if total > self._chunk_size: + logger.warning( + f"Created a chunk of size {total}, " + f"which is longer than the specified {self._chunk_size}" + ) + if len(current_doc) > 0: + doc = self._join_docs(current_doc, separator) + if doc is not None: + docs.append(doc) + # Keep on popping if: + # - we have a larger chunk than in the chunk overlap + # - or if we still have any chunks and the length is long + while total > self._chunk_overlap or ( + total + + _len + + (separator_len if len(current_doc) > 0 else 0) + > self._chunk_size + and total > 0 + ): + total -= self._length_function(current_doc[0]) + ( + separator_len if len(current_doc) > 1 else 0 + ) + current_doc = current_doc[1:] + current_doc.append(d) + total += _len + (separator_len if len(current_doc) > 1 else 0) + doc = self._join_docs(current_doc, separator) + if doc is not None: + docs.append(doc) + return docs + + @classmethod + def from_huggingface_tokenizer( + cls, tokenizer: Any, **kwargs: Any + ) -> TextSplitter: + """Text splitter that uses HuggingFace tokenizer to count length.""" + try: + from transformers import PreTrainedTokenizerBase + + if not isinstance(tokenizer, PreTrainedTokenizerBase): + raise ValueError( + "Tokenizer received was not an instance of PreTrainedTokenizerBase" + ) + + def _huggingface_tokenizer_length(text: str) -> int: + return len(tokenizer.encode(text)) + + except ImportError: + raise ValueError( + "Could not import transformers python package. " + "Please install it with `pip install transformers`." + ) from None + return cls(length_function=_huggingface_tokenizer_length, **kwargs) + + @classmethod + def from_tiktoken_encoder( + cls: Type[TS], + encoding_name: str = "gpt2", + model: Optional[str] = None, + allowed_special: Literal["all"] | AbstractSet[str] = set(), + disallowed_special: Literal["all"] | Collection[str] = "all", + **kwargs: Any, + ) -> TS: + """Text splitter that uses tiktoken encoder to count length.""" + try: + import tiktoken + except ImportError: + raise ImportError("""Could not import tiktoken python package. + This is needed in order to calculate max_tokens_for_prompt. 
+ Please install it with `pip install tiktoken`.""") from None + + if model is not None: + enc = tiktoken.encoding_for_model(model) + else: + enc = tiktoken.get_encoding(encoding_name) + + def _tiktoken_encoder(text: str) -> int: + return len( + enc.encode( + text, + allowed_special=allowed_special, + disallowed_special=disallowed_special, + ) + ) + + if issubclass(cls, TokenTextSplitter): + extra_kwargs = { + "encoding_name": encoding_name, + "model": model, + "allowed_special": allowed_special, + "disallowed_special": disallowed_special, + } + kwargs = {**kwargs, **extra_kwargs} + + return cls(length_function=_tiktoken_encoder, **kwargs) + + def transform_documents( + self, documents: Sequence[SplitterDocument], **kwargs: Any + ) -> Sequence[SplitterDocument]: + """Transform sequence of documents by splitting them.""" + return self.split_documents(list(documents)) + + +class CharacterTextSplitter(TextSplitter): + """Splitting text that looks at characters.""" + + DEFAULT_SEPARATOR: str = "\n\n" + + def __init__( + self, + separator: str = DEFAULT_SEPARATOR, + is_separator_regex: bool = False, + **kwargs: Any, + ) -> None: + """Create a new TextSplitter.""" + super().__init__(**kwargs) + self._separator = separator + self._is_separator_regex = is_separator_regex + + def split_text(self, text: str) -> list[str]: + """Split incoming text and return chunks.""" + # First we naively split the large input into a bunch of smaller ones. + separator = ( + self._separator + if self._is_separator_regex + else re.escape(self._separator) + ) + splits = _split_text_with_regex(text, separator, self._keep_separator) + _separator = "" if self._keep_separator else self._separator + return self._merge_splits(splits, _separator) + + +class LineType(TypedDict): + """Line type as typed dict.""" + + metadata: dict[str, str] + content: str + + +class HeaderType(TypedDict): + """Header type as typed dict.""" + + level: int + name: str + data: str + + +class MarkdownHeaderTextSplitter: + """Splitting markdown files based on specified headers.""" + + def __init__( + self, + headers_to_split_on: list[Tuple[str, str]], + return_each_line: bool = False, + strip_headers: bool = True, + ): + """Create a new MarkdownHeaderTextSplitter. 
+ + Args: + headers_to_split_on: Headers we want to track + return_each_line: Return each line w/ associated headers + strip_headers: Strip split headers from the content of the chunk + """ + # Output line-by-line or aggregated into chunks w/ common headers + self.return_each_line = return_each_line + # Given the headers we want to split on, + # (e.g., "#, ##, etc") order by length + self.headers_to_split_on = sorted( + headers_to_split_on, key=lambda split: len(split[0]), reverse=True + ) + # Strip headers split headers from the content of the chunk + self.strip_headers = strip_headers + + def aggregate_lines_to_chunks( + self, lines: list[LineType] + ) -> list[SplitterDocument]: + """Combine lines with common metadata into chunks + Args: + lines: Line of text / associated header metadata + """ + aggregated_chunks: list[LineType] = [] + + for line in lines: + if ( + aggregated_chunks + and aggregated_chunks[-1]["metadata"] == line["metadata"] + ): + # If the last line in the aggregated list + # has the same metadata as the current line, + # append the current content to the last lines's content + aggregated_chunks[-1]["content"] += " \n" + line["content"] + elif ( + aggregated_chunks + and aggregated_chunks[-1]["metadata"] != line["metadata"] + # may be issues if other metadata is present + and len(aggregated_chunks[-1]["metadata"]) + < len(line["metadata"]) + and aggregated_chunks[-1]["content"].split("\n")[-1][0] == "#" + and not self.strip_headers + ): + # If the last line in the aggregated list + # has different metadata as the current line, + # and has shallower header level than the current line, + # and the last line is a header, + # and we are not stripping headers, + # append the current content to the last line's content + aggregated_chunks[-1]["content"] += " \n" + line["content"] + # and update the last line's metadata + aggregated_chunks[-1]["metadata"] = line["metadata"] + else: + # Otherwise, append the current line to the aggregated list + aggregated_chunks.append(line) + + return [ + SplitterDocument( + page_content=chunk["content"], metadata=chunk["metadata"] + ) + for chunk in aggregated_chunks + ] + + def split_text(self, text: str) -> list[SplitterDocument]: + """Split markdown file + Args: + text: Markdown file""" + + # Split the input text by newline character ("\n"). 
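(Illustrative sketch, not part of the diff: with the default strip_headers=True, each returned SplitterDocument carries the header trail as metadata and the header lines themselves are removed from page_content.)

    from shared.utils.splitter.text import MarkdownHeaderTextSplitter

    md = "# Intro\nSome intro text.\n## Details\nMore text here."
    splitter = MarkdownHeaderTextSplitter(
        headers_to_split_on=[("#", "Header 1"), ("##", "Header 2")]
    )
    for doc in splitter.split_text(md):
        print(doc.metadata, "->", doc.page_content)
    # {'Header 1': 'Intro'} -> Some intro text.
    # {'Header 1': 'Intro', 'Header 2': 'Details'} -> More text here.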
+ lines = text.split("\n") + # Final output + lines_with_metadata: list[LineType] = [] + # Content and metadata of the chunk currently being processed + current_content: list[str] = [] + current_metadata: dict[str, str] = {} + # Keep track of the nested header structure + # header_stack: list[dict[str, int | str]] = [] + header_stack: list[HeaderType] = [] + initial_metadata: dict[str, str] = {} + + in_code_block = False + opening_fence = "" + + for line in lines: + stripped_line = line.strip() + + if not in_code_block: + # Exclude inline code spans + if ( + stripped_line.startswith("```") + and stripped_line.count("```") == 1 + ): + in_code_block = True + opening_fence = "```" + elif stripped_line.startswith("~~~"): + in_code_block = True + opening_fence = "~~~" + else: + if stripped_line.startswith(opening_fence): + in_code_block = False + opening_fence = "" + + if in_code_block: + current_content.append(stripped_line) + continue + + # Check each line against each of the header types (e.g., #, ##) + for sep, name in self.headers_to_split_on: + # Check if line starts with a header that we intend to split on + if stripped_line.startswith(sep) and ( + # Header with no text OR header is followed by space + # Both are valid conditions that sep is being used a header + len(stripped_line) == len(sep) + or stripped_line[len(sep)] == " " + ): + # Ensure we are tracking the header as metadata + if name is not None: + # Get the current header level + current_header_level = sep.count("#") + + # Pop out headers of lower or same level from the stack + while ( + header_stack + and header_stack[-1]["level"] + >= current_header_level + ): + # We have encountered a new header + # at the same or higher level + popped_header = header_stack.pop() + # Clear the metadata for the + # popped header in initial_metadata + if popped_header["name"] in initial_metadata: + initial_metadata.pop(popped_header["name"]) + + # Push the current header to the stack + header: HeaderType = { + "level": current_header_level, + "name": name, + "data": stripped_line[len(sep) :].strip(), + } + header_stack.append(header) + # Update initial_metadata with the current header + initial_metadata[name] = header["data"] + + # Add the previous line to the lines_with_metadata + # only if current_content is not empty + if current_content: + lines_with_metadata.append( + { + "content": "\n".join(current_content), + "metadata": current_metadata.copy(), + } + ) + current_content.clear() + + if not self.strip_headers: + current_content.append(stripped_line) + + break + else: + if stripped_line: + current_content.append(stripped_line) + elif current_content: + lines_with_metadata.append( + { + "content": "\n".join(current_content), + "metadata": current_metadata.copy(), + } + ) + current_content.clear() + + current_metadata = initial_metadata.copy() + + if current_content: + lines_with_metadata.append( + { + "content": "\n".join(current_content), + "metadata": current_metadata, + } + ) + + # lines_with_metadata has each line with associated header metadata + # aggregate these into chunks based on common metadata + if not self.return_each_line: + return self.aggregate_lines_to_chunks(lines_with_metadata) + else: + return [ + SplitterDocument( + page_content=chunk["content"], metadata=chunk["metadata"] + ) + for chunk in lines_with_metadata + ] + + +class ElementType(TypedDict): + """Element type as typed dict.""" + + url: str + xpath: str + content: str + metadata: dict[str, str] + + +class HTMLHeaderTextSplitter: + """Splitting HTML files based 
on specified headers. + + Requires lxml package. + """ + + def __init__( + self, + headers_to_split_on: list[Tuple[str, str]], + return_each_element: bool = False, + ): + """Create a new HTMLHeaderTextSplitter. + + Args: + headers_to_split_on: list of tuples of headers we want to track + mapped to (arbitrary) keys for metadata. Allowed header values: + h1, h2, h3, h4, h5, h6 + e.g. [("h1", "Header 1"), ("h2", "Header 2)]. + return_each_element: Return each element w/ associated headers. + """ + # Output element-by-element or aggregated into chunks w/ common headers + self.return_each_element = return_each_element + self.headers_to_split_on = sorted(headers_to_split_on) + + def aggregate_elements_to_chunks( + self, elements: list[ElementType] + ) -> list[SplitterDocument]: + """Combine elements with common metadata into chunks. + + Args: + elements: HTML element content with associated identifying + info and metadata + """ + aggregated_chunks: list[ElementType] = [] + + for element in elements: + if ( + aggregated_chunks + and aggregated_chunks[-1]["metadata"] == element["metadata"] + ): + # If the last element in the aggregated list + # has the same metadata as the current element, + # append the current content to the last element's content + aggregated_chunks[-1]["content"] += " \n" + element["content"] + else: + # Otherwise, append the current element to the aggregated list + aggregated_chunks.append(element) + + return [ + SplitterDocument( + page_content=chunk["content"], metadata=chunk["metadata"] + ) + for chunk in aggregated_chunks + ] + + def split_text_from_url(self, url: str) -> list[SplitterDocument]: + """Split HTML from web URL. + + Args: + url: web URL + """ + r = requests.get(url) + return self.split_text_from_file(BytesIO(r.content)) + + def split_text(self, text: str) -> list[SplitterDocument]: + """Split HTML text string. + + Args: + text: HTML text + """ + return self.split_text_from_file(StringIO(text)) + + def split_text_from_file(self, file: Any) -> list[SplitterDocument]: + """Split HTML file. + + Args: + file: HTML file + """ + try: + from lxml import etree + except ImportError: + raise ImportError( + "Unable to import lxml, run `pip install lxml`." + ) from None + # use lxml library to parse html document and return xml ElementTree + # Explicitly encoding in utf-8 allows non-English + # html files to be processed without garbled characters + parser = etree.HTMLParser(encoding="utf-8") + tree = etree.parse(file, parser) + + # document transformation for "structure-aware" chunking is handled + # with xsl. See comments in html_chunks_with_headers.xslt for more + # detailed information. 
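(Illustrative sketch, not part of the diff: split_text_on_tokens slides a tokens_per_chunk window with chunk_overlap over the encoded ids; it can be wired directly to tiktoken, assuming tiktoken is installed.)

    import tiktoken
    from shared.utils.splitter.text import Tokenizer, split_text_on_tokens

    enc = tiktoken.get_encoding("cl100k_base")
    tok = Tokenizer(
        chunk_overlap=10,
        tokens_per_chunk=50,
        decode=enc.decode,
        encode=enc.encode,
    )
    chunks = split_text_on_tokens(text="some long text " * 200, tokenizer=tok)
    print(len(chunks), chunks[0][:40])   # 50-token windows with a 10-token overlap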
+ xslt_path = ( + pathlib.Path(__file__).parent + / "document_transformers/xsl/html_chunks_with_headers.xslt" + ) + xslt_tree = etree.parse(xslt_path) + transform = etree.XSLT(xslt_tree) + result = transform(tree) + result_dom = etree.fromstring(str(result)) + + # create filter and mapping for header metadata + header_filter = [header[0] for header in self.headers_to_split_on] + header_mapping = dict(self.headers_to_split_on) + + # map xhtml namespace prefix + ns_map = {"h": "http://www.w3.org/1999/xhtml"} + + # build list of elements from DOM + elements = [] + for element in result_dom.findall("*//*", ns_map): + if element.findall("*[@class='headers']") or element.findall( + "*[@class='chunk']" + ): + elements.append( + ElementType( + url=file, + xpath="".join( + [ + node.text + for node in element.findall( + "*[@class='xpath']", ns_map + ) + ] + ), + content="".join( + [ + node.text + for node in element.findall( + "*[@class='chunk']", ns_map + ) + ] + ), + metadata={ + # Add text of specified headers to + # metadata using header mapping. + header_mapping[node.tag]: node.text + for node in filter( + lambda x: x.tag in header_filter, + element.findall( + "*[@class='headers']/*", ns_map + ), + ) + }, + ) + ) + + if not self.return_each_element: + return self.aggregate_elements_to_chunks(elements) + else: + return [ + SplitterDocument( + page_content=chunk["content"], metadata=chunk["metadata"] + ) + for chunk in elements + ] + + +# should be in newer Python versions (3.11+) +# @dataclass(frozen=True, kw_only=True, slots=True) +@dataclass(frozen=True) +class Tokenizer: + """Tokenizer data class.""" + + chunk_overlap: int + """Overlap in tokens between chunks.""" + tokens_per_chunk: int + """Maximum number of tokens per chunk.""" + decode: Callable[[list[int]], str] + """Function to decode a list of token ids to a string.""" + encode: Callable[[str], list[int]] + """Function to encode a string to a list of token ids.""" + + +def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]: + """Split incoming text and return chunks using tokenizer.""" + splits: list[str] = [] + input_ids = tokenizer.encode(text) + start_idx = 0 + cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids)) + chunk_ids = input_ids[start_idx:cur_idx] + while start_idx < len(input_ids): + splits.append(tokenizer.decode(chunk_ids)) + if cur_idx == len(input_ids): + break + start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap + cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids)) + chunk_ids = input_ids[start_idx:cur_idx] + return splits + + +class TokenTextSplitter(TextSplitter): + """Splitting text to tokens using model tokenizer.""" + + def __init__( + self, + encoding_name: str = "gpt2", + model: Optional[str] = None, + allowed_special: Literal["all"] | AbstractSet[str] = set(), + disallowed_special: Literal["all"] | Collection[str] = "all", + **kwargs: Any, + ) -> None: + """Create a new TextSplitter.""" + super().__init__(**kwargs) + try: + import tiktoken + except ImportError: + raise ImportError( + "Could not import tiktoken python package. " + "This is needed in order to for TokenTextSplitter. " + "Please install it with `pip install tiktoken`." 
+ ) from None + + if model is not None: + enc = tiktoken.encoding_for_model(model) + else: + enc = tiktoken.get_encoding(encoding_name) + self._tokenizer = enc + self._allowed_special = allowed_special + self._disallowed_special = disallowed_special + + def split_text(self, text: str) -> list[str]: + def _encode(_text: str) -> list[int]: + return self._tokenizer.encode( + _text, + allowed_special=self._allowed_special, + disallowed_special=self._disallowed_special, + ) + + tokenizer = Tokenizer( + chunk_overlap=self._chunk_overlap, + tokens_per_chunk=self._chunk_size, + decode=self._tokenizer.decode, + encode=_encode, + ) + + return split_text_on_tokens(text=text, tokenizer=tokenizer) + + +class SentenceTransformersTokenTextSplitter(TextSplitter): + """Splitting text to tokens using sentence model tokenizer.""" + + def __init__( + self, + chunk_overlap: int = 50, + model: str = "sentence-transformers/all-mpnet-base-v2", + tokens_per_chunk: Optional[int] = None, + **kwargs: Any, + ) -> None: + """Create a new TextSplitter.""" + super().__init__(**kwargs, chunk_overlap=chunk_overlap) + + try: + from sentence_transformers import SentenceTransformer + except ImportError: + raise ImportError( + """Could not import sentence_transformer python package. + This is needed in order to for + SentenceTransformersTokenTextSplitter. + Please install it with `pip install sentence-transformers`. + """ + ) from None + + self.model = model + self._model = SentenceTransformer(self.model, trust_remote_code=True) + self.tokenizer = self._model.tokenizer + self._initialize_chunk_configuration(tokens_per_chunk=tokens_per_chunk) + + def _initialize_chunk_configuration( + self, *, tokens_per_chunk: Optional[int] + ) -> None: + self.maximum_tokens_per_chunk = cast(int, self._model.max_seq_length) + + if tokens_per_chunk is None: + self.tokens_per_chunk = self.maximum_tokens_per_chunk + else: + self.tokens_per_chunk = tokens_per_chunk + + if self.tokens_per_chunk > self.maximum_tokens_per_chunk: + raise ValueError( + f"The token limit of the models '{self.model}'" + f" is: {self.maximum_tokens_per_chunk}." + f" Argument tokens_per_chunk={self.tokens_per_chunk}" + f" > maximum token limit." 
+ ) + + def split_text(self, text: str) -> list[str]: + def encode_strip_start_and_stop_token_ids(text: str) -> list[int]: + return self._encode(text)[1:-1] + + tokenizer = Tokenizer( + chunk_overlap=self._chunk_overlap, + tokens_per_chunk=self.tokens_per_chunk, + decode=self.tokenizer.decode, + encode=encode_strip_start_and_stop_token_ids, + ) + + return split_text_on_tokens(text=text, tokenizer=tokenizer) + + def count_tokens(self, *, text: str) -> int: + return len(self._encode(text)) + + _max_length_equal_32_bit_integer: int = 2**32 + + def _encode(self, text: str) -> list[int]: + token_ids_with_start_and_end_token_ids = self.tokenizer.encode( + text, + max_length=self._max_length_equal_32_bit_integer, + truncation="do_not_truncate", + ) + return token_ids_with_start_and_end_token_ids + + +class Language(str, Enum): + """Enum of the programming languages.""" + + CPP = "cpp" + GO = "go" + JAVA = "java" + KOTLIN = "kotlin" + JS = "js" + TS = "ts" + PHP = "php" + PROTO = "proto" + PYTHON = "python" + RST = "rst" + RUBY = "ruby" + RUST = "rust" + SCALA = "scala" + SWIFT = "swift" + MARKDOWN = "markdown" + LATEX = "latex" + HTML = "html" + SOL = "sol" + CSHARP = "csharp" + COBOL = "cobol" + C = "c" + LUA = "lua" + PERL = "perl" + + +class RecursiveCharacterTextSplitter(TextSplitter): + """Splitting text by recursively look at characters. + + Recursively tries to split by different characters to find one that works. + """ + + def __init__( + self, + separators: Optional[list[str]] = None, + keep_separator: bool = True, + is_separator_regex: bool = False, + chunk_size: int = 4000, + chunk_overlap: int = 200, + **kwargs: Any, + ) -> None: + """Create a new TextSplitter.""" + super().__init__( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + keep_separator=keep_separator, + **kwargs, + ) + self._separators = separators or ["\n\n", "\n", " ", ""] + self._is_separator_regex = is_separator_regex + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + + def _split_text(self, text: str, separators: list[str]) -> list[str]: + """Split incoming text and return chunks.""" + final_chunks = [] + # Get appropriate separator to use + separator = separators[-1] + new_separators = [] + for i, _s in enumerate(separators): + _separator = _s if self._is_separator_regex else re.escape(_s) + if _s == "": + separator = _s + break + if re.search(_separator, text): + separator = _s + new_separators = separators[i + 1 :] + break + + _separator = ( + separator if self._is_separator_regex else re.escape(separator) + ) + splits = _split_text_with_regex(text, _separator, self._keep_separator) + + # Now go merging things, recursively splitting longer texts. 
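(Illustrative sketch, not part of the diff: typical use of the class that shared.utils re-exports at the top level; from_language swaps in the per-language separators defined further down in this file.)

    from shared.utils import RecursiveCharacterTextSplitter
    from shared.utils.splitter.text import Language

    text = "Paragraph one.\n\nParagraph two is a little longer. " * 40
    splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=40)
    chunks = splitter.split_text(text)
    print(len(chunks))   # roughly 200-character chunks, split on "\n\n", "\n", " ", ""

    py_splitter = RecursiveCharacterTextSplitter.from_language(Language.PYTHON, chunk_size=200)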
+        _good_splits = []
+        _separator = "" if self._keep_separator else separator
+        for s in splits:
+            if self._length_function(s) < self._chunk_size:
+                _good_splits.append(s)
+            else:
+                if _good_splits:
+                    merged_text = self._merge_splits(_good_splits, _separator)
+                    final_chunks.extend(merged_text)
+                    _good_splits = []
+                if not new_separators:
+                    final_chunks.append(s)
+                else:
+                    other_info = self._split_text(s, new_separators)
+                    final_chunks.extend(other_info)
+        if _good_splits:
+            merged_text = self._merge_splits(_good_splits, _separator)
+            final_chunks.extend(merged_text)
+        return final_chunks
+
+    def split_text(self, text: str) -> list[str]:
+        return self._split_text(text, self._separators)
+
+    @classmethod
+    def from_language(
+        cls, language: Language, **kwargs: Any
+    ) -> RecursiveCharacterTextSplitter:
+        separators = cls.get_separators_for_language(language)
+        return cls(separators=separators, is_separator_regex=True, **kwargs)
+
+    @staticmethod
+    def get_separators_for_language(language: Language) -> list[str]:
+        if language == Language.CPP:
+            return [
+                # Split along class definitions
+                "\nclass ",
+                # Split along function definitions
+                "\nvoid ",
+                "\nint ",
+                "\nfloat ",
+                "\ndouble ",
+                # Split along control flow statements
+                "\nif ",
+                "\nfor ",
+                "\nwhile ",
+                "\nswitch ",
+                "\ncase ",
+                # Split by the normal type of lines
+                "\n\n",
+                "\n",
+                " ",
+                "",
+            ]
+        elif language == Language.GO:
+            return [
+                # Split along function definitions
+                "\nfunc ",
+                "\nvar ",
+                "\nconst ",
+                "\ntype ",
+                # Split along control flow statements
+                "\nif ",
+                "\nfor ",
+                "\nswitch ",
+                "\ncase ",
+                # Split by the normal type of lines
+                "\n\n",
+                "\n",
+                " ",
+                "",
+            ]
+        elif language == Language.JAVA:
+            return [
+                # Split along class definitions
+                "\nclass ",
+                # Split along method definitions
+                "\npublic ",
+                "\nprotected ",
+                "\nprivate ",
+                "\nstatic ",
+                # Split along control flow statements
+                "\nif ",
+                "\nfor ",
+                "\nwhile ",
+                "\nswitch ",
+                "\ncase ",
+                # Split by the normal type of lines
+                "\n\n",
+                "\n",
+                " ",
+                "",
+            ]
+        elif language == Language.KOTLIN:
+            return [
+                # Split along class definitions
+                "\nclass ",
+                # Split along method definitions
+                "\npublic ",
+                "\nprotected ",
+                "\nprivate ",
+                "\ninternal ",
+                "\ncompanion ",
+                "\nfun ",
+                "\nval ",
+                "\nvar ",
+                # Split along control flow statements
+                "\nif ",
+                "\nfor ",
+                "\nwhile ",
+                "\nwhen ",
+                "\ncase ",
+                "\nelse ",
+                # Split by the normal type of lines
+                "\n\n",
+                "\n",
+                " ",
+                "",
+            ]
+        elif language == Language.JS:
+            return [
+                # Split along function definitions
+                "\nfunction ",
+                "\nconst ",
+                "\nlet ",
+                "\nvar ",
+                "\nclass ",
+                # Split along control flow statements
+                "\nif ",
+                "\nfor ",
+                "\nwhile ",
+                "\nswitch ",
+                "\ncase ",
+                "\ndefault ",
+                # Split by the normal type of lines
+                "\n\n",
+                "\n",
+                " ",
+                "",
+            ]
+        elif language == Language.TS:
+            return [
+                "\nenum ",
+                "\ninterface ",
+                "\nnamespace ",
+                "\ntype ",
+                # Split along class definitions
+                "\nclass ",
+                # Split along function definitions
+                "\nfunction ",
+                "\nconst ",
+                "\nlet ",
+                "\nvar ",
+                # Split along control flow statements
+                "\nif ",
+                "\nfor ",
+                "\nwhile ",
+                "\nswitch ",
+                "\ncase ",
+                "\ndefault ",
+                # Split by the normal type of lines
+                "\n\n",
+                "\n",
+                " ",
+                "",
+            ]
+        elif language == Language.PHP:
+            return [
+                # Split along function definitions
+                "\nfunction ",
+                # Split along class definitions
+                "\nclass ",
+                # Split along control flow statements
+                "\nif ",
+                "\nforeach ",
+                "\nwhile ",
+                "\ndo ",
+                "\nswitch ",
+                "\ncase ",
+                # Split by the normal type of lines
+                "\n\n",
+                "\n",
+                " ",
+                "",
+            ]
+        elif language == Language.PROTO:
+            return [
+                # Split along message definitions
+                "\nmessage ",
+                # Split along service definitions
+                "\nservice ",
+                # Split along enum definitions
+                "\nenum ",
+                # Split along option definitions
+                "\noption ",
+                # Split along import statements
+                "\nimport ",
+                # Split along syntax declarations
+                "\nsyntax ",
+                # Split by the normal type of lines
+                "\n\n",
+                "\n",
+                " ",
+                "",
+            ]
+        elif language == Language.PYTHON:
+            return [
+                # First, try to split along class definitions
+                "\nclass ",
+                "\ndef ",
+                "\n\tdef ",
+                # Now split by the normal type of lines
+                "\n\n",
+                "\n",
+                " ",
+                "",
+            ]
+        elif language == Language.RST:
+            return [
+                # Split along section titles
+                "\n=+\n",
+                "\n-+\n",
+                "\n\\*+\n",
+                # Split along directive markers
+                "\n\n.. *\n\n",
+                # Split by the normal type of lines
+                "\n\n",
+                "\n",
+                " ",
+                "",
+            ]
+        elif language == Language.RUBY:
+            return [
+                # Split along method definitions
+                "\ndef ",
+                "\nclass ",
+                # Split along control flow statements
+                "\nif ",
+                "\nunless ",
+                "\nwhile ",
+                "\nfor ",
+                "\ndo ",
+                "\nbegin ",
+                "\nrescue ",
+                # Split by the normal type of lines
+                "\n\n",
+                "\n",
+                " ",
+                "",
+            ]
+        elif language == Language.RUST:
+            return [
+                # Split along function definitions
+                "\nfn ",
+                "\nconst ",
+                "\nlet ",
+                # Split along control flow statements
+                "\nif ",
+                "\nwhile ",
+                "\nfor ",
+                "\nloop ",
+                "\nmatch ",
+                "\nconst ",
+                # Split by the normal type of lines
+                "\n\n",
+                "\n",
+                " ",
+                "",
+            ]
+        elif language == Language.SCALA:
+            return [
+                # Split along class definitions
+                "\nclass ",
+                "\nobject ",
+                # Split along method definitions
+                "\ndef ",
+                "\nval ",
+                "\nvar ",
+                # Split along control flow statements
+                "\nif ",
+                "\nfor ",
+                "\nwhile ",
+                "\nmatch ",
+                "\ncase ",
+                # Split by the normal type of lines
+                "\n\n",
+                "\n",
+                " ",
+                "",
+            ]
+        elif language == Language.SWIFT:
+            return [
+                # Split along function definitions
+                "\nfunc ",
+                # Split along class definitions
+                "\nclass ",
+                "\nstruct ",
+                "\nenum ",
+                # Split along control flow statements
+                "\nif ",
+                "\nfor ",
+                "\nwhile ",
+                "\ndo ",
+                "\nswitch ",
+                "\ncase ",
+                # Split by the normal type of lines
+                "\n\n",
+                "\n",
+                " ",
+                "",
+            ]
+        elif language == Language.MARKDOWN:
+            return [
+                # First, try to split along Markdown headings
+                # (starting with level 2)
+                "\n#{1,6} ",
+                # Note the alternative syntax for headings (below)
+                # is not handled here
+                # Heading level 2
+                # ---------------
+                # End of code block
+                "```\n",
+                # Horizontal lines
+                "\n\\*\\*\\*+\n",
+                "\n---+\n",
+                "\n___+\n",
+                # Note that the horizontal-rule patterns above only match
+                # runs of three or more *, -, or _ that occupy a line on
+                # their own
+                "\n\n",
+                "\n",
+                " ",
+                "",
+            ]
+        elif language == Language.LATEX:
+            return [
+                # First, try to split along Latex sections
+                "\n\\\\chapter{",
+                "\n\\\\section{",
+                "\n\\\\subsection{",
+                "\n\\\\subsubsection{",
+                # Now split by environments
+                "\n\\\\begin{enumerate}",
+                "\n\\\\begin{itemize}",
+                "\n\\\\begin{description}",
+                "\n\\\\begin{list}",
+                "\n\\\\begin{quote}",
+                "\n\\\\begin{quotation}",
+                "\n\\\\begin{verse}",
+                "\n\\\\begin{verbatim}",
+                # Now split by math environments
+                "\n\\\\begin{align}",
+                "$$",
+                "$",
+                # Now split by the normal type of lines
+                " ",
+                "",
+            ]
+        elif language == Language.HTML:
+            return [
+                # First, try to split along HTML tags
+                "<body",
+                "<div",
+                "<p",
+                "<br",
+                "<li",
+                "<h1",
+                "<h2",
+                "<h3",
+                "<h4",
+                "<h5",
+                "<h6",
+                "<span",
"<table", + "<tr", + "<td", + "<th", + "<ul", + "<ol", + "<header", + "<footer", + "<nav", + # Head + "<head", + "<style", + "<script", + "<meta", + "<title", + "", + ] + elif language == Language.CSHARP: + return [ + "\ninterface ", + "\nenum ", + "\nimplements ", + "\ndelegate ", + "\nevent ", + # Split along class definitions + "\nclass ", + "\nabstract ", + # Split along method definitions + "\npublic ", + "\nprotected ", + "\nprivate ", + "\nstatic ", + "\nreturn ", + # Split along control flow statements + "\nif ", + "\ncontinue ", + "\nfor ", + "\nforeach ", + "\nwhile ", + "\nswitch ", + "\nbreak ", + "\ncase ", + "\nelse ", + # Split by exceptions + "\ntry ", + "\nthrow ", + "\nfinally ", + "\ncatch ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.SOL: + return [ + # Split along compiler information definitions + "\npragma ", + "\nusing ", + # Split along contract definitions + "\ncontract ", + "\ninterface ", + "\nlibrary ", + # Split along method definitions + "\nconstructor ", + "\ntype ", + "\nfunction ", + "\nevent ", + "\nmodifier ", + "\nerror ", + "\nstruct ", + "\nenum ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\ndo while ", + "\nassembly ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.COBOL: + return [ + # Split along divisions + "\nIDENTIFICATION DIVISION.", + "\nENVIRONMENT DIVISION.", + "\nDATA DIVISION.", + "\nPROCEDURE DIVISION.", + # Split along sections within DATA DIVISION + "\nWORKING-STORAGE SECTION.", + "\nLINKAGE SECTION.", + "\nFILE SECTION.", + # Split along sections within PROCEDURE DIVISION + "\nINPUT-OUTPUT SECTION.", + # Split along paragraphs and common statements + "\nOPEN ", + "\nCLOSE ", + "\nREAD ", + "\nWRITE ", + "\nIF ", + "\nELSE ", + "\nMOVE ", + "\nPERFORM ", + "\nUNTIL ", + "\nVARYING ", + "\nACCEPT ", + "\nDISPLAY ", + "\nSTOP RUN.", + # Split by the normal type of lines + "\n", + " ", + "", + ] + + else: + raise ValueError( + f"Language {language} is not supported! " + f"Please choose from {list(Language)}" + ) + + +class NLTKTextSplitter(TextSplitter): + """Splitting text using NLTK package.""" + + def __init__( + self, separator: str = "\n\n", language: str = "english", **kwargs: Any + ) -> None: + """Initialize the NLTK splitter.""" + super().__init__(**kwargs) + try: + from nltk.tokenize import sent_tokenize + + self._tokenizer = sent_tokenize + except ImportError: + raise ImportError("""NLTK is not installed, please install it with + `pip install nltk`.""") from None + self._separator = separator + self._language = language + + def split_text(self, text: str) -> list[str]: + """Split incoming text and return chunks.""" + # First we naively split the large input into a bunch of smaller ones. + splits = self._tokenizer(text, language=self._language) + return self._merge_splits(splits, self._separator) + + +class SpacyTextSplitter(TextSplitter): + """Splitting text using Spacy package. + + Per default, Spacy's `en_core_web_sm` model is used and + its default max_length is 1000000 (it is the length of maximum character + this model takes which can be increased for large files). For a faster, + but potentially less accurate splitting, you can use `pipe='sentencizer'`. 
+    """
+
+    def __init__(
+        self,
+        separator: str = "\n\n",
+        pipe: str = "en_core_web_sm",
+        max_length: int = 1_000_000,
+        **kwargs: Any,
+    ) -> None:
+        """Initialize the spacy text splitter."""
+        super().__init__(**kwargs)
+        self._tokenizer = _make_spacy_pipe_for_splitting(
+            pipe, max_length=max_length
+        )
+        self._separator = separator
+
+    def split_text(self, text: str) -> list[str]:
+        """Split incoming text and return chunks."""
+        splits = (s.text for s in self._tokenizer(text).sents)
+        return self._merge_splits(splits, self._separator)
+
+
+class KonlpyTextSplitter(TextSplitter):
+    """Splitting text using Konlpy package.
+
+    It is good for splitting Korean text.
+    """
+
+    def __init__(
+        self,
+        separator: str = "\n\n",
+        **kwargs: Any,
+    ) -> None:
+        """Initialize the Konlpy text splitter."""
+        super().__init__(**kwargs)
+        self._separator = separator
+        try:
+            from konlpy.tag import Kkma
+        except ImportError:
+            raise ImportError("""
+            Konlpy is not installed, please install it with
+            `pip install konlpy`
+            """) from None
+        self.kkma = Kkma()
+
+    def split_text(self, text: str) -> list[str]:
+        """Split incoming text and return chunks."""
+        splits = self.kkma.sentences(text)
+        return self._merge_splits(splits, self._separator)
+
+
+# For backwards compatibility
+class PythonCodeTextSplitter(RecursiveCharacterTextSplitter):
+    """Attempts to split the text along Python syntax."""
+
+    def __init__(self, **kwargs: Any) -> None:
+        """Initialize a PythonCodeTextSplitter."""
+        separators = self.get_separators_for_language(Language.PYTHON)
+        super().__init__(separators=separators, **kwargs)
+
+
+class MarkdownTextSplitter(RecursiveCharacterTextSplitter):
+    """Attempts to split the text along Markdown-formatted headings."""
+
+    def __init__(self, **kwargs: Any) -> None:
+        """Initialize a MarkdownTextSplitter."""
+        separators = self.get_separators_for_language(Language.MARKDOWN)
+        super().__init__(separators=separators, **kwargs)
+
+
+class LatexTextSplitter(RecursiveCharacterTextSplitter):
+    """Attempts to split the text along Latex-formatted layout elements."""
+
+    def __init__(self, **kwargs: Any) -> None:
+        """Initialize a LatexTextSplitter."""
+        separators = self.get_separators_for_language(Language.LATEX)
+        super().__init__(separators=separators, **kwargs)
+
+
+class RecursiveJsonSplitter:
+    def __init__(
+        self, max_chunk_size: int = 2000, min_chunk_size: Optional[int] = None
+    ):
+        super().__init__()
+        self.max_chunk_size = max_chunk_size
+        self.min_chunk_size = (
+            min_chunk_size
+            if min_chunk_size is not None
+            else max(max_chunk_size - 200, 50)
+        )
+
+    @staticmethod
+    def _json_size(data: dict) -> int:
+        """Calculate the size of the serialized JSON object."""
+        return len(json.dumps(data))
+
+    @staticmethod
+    def _set_nested_dict(d: dict, path: list[str], value: Any) -> None:
+        """Set a value in a nested dictionary based on the given path."""
+        for key in path[:-1]:
+            d = d.setdefault(key, {})
+        d[path[-1]] = value
+
+    def _list_to_dict_preprocessing(self, data: Any) -> Any:
+        if isinstance(data, dict):
+            # Process each key-value pair in the dictionary
+            return {
+                k: self._list_to_dict_preprocessing(v) for k, v in data.items()
+            }
+        elif isinstance(data, list):
+            # Convert the list to a dictionary with index-based keys
+            return {
+                str(i): self._list_to_dict_preprocessing(item)
+                for i, item in enumerate(data)
+            }
+        else:
+            # The item is neither a dict nor a list, return unchanged
+            return data
+
+    def _json_split(
+        self,
+        data: dict[str, Any],
+        current_path: list[str] | None = None,
+        chunks: list[dict] | None = None,
+    ) -> list[dict]:
+        """Split json into maximum size dictionaries while preserving
+        structure."""
+        if current_path is None:
+            current_path = []
+        if chunks is None:
+            chunks = [{}]
+
+        if isinstance(data, dict):
+            for key, value in data.items():
+                new_path = current_path + [key]
+                chunk_size = self._json_size(chunks[-1])
+                size = self._json_size({key: value})
+                remaining = self.max_chunk_size - chunk_size
+
+                if size < remaining:
+                    # Add item to current chunk
+                    self._set_nested_dict(chunks[-1], new_path, value)
+                else:
+                    if chunk_size >= self.min_chunk_size:
+                        # Chunk is big enough, start a new chunk
+                        chunks.append({})
+
+                    # Iterate
+                    self._json_split(value, new_path, chunks)
+        else:
+            # handle single item
+            self._set_nested_dict(chunks[-1], current_path, data)
+        return chunks
+
+    def split_json(
+        self,
+        json_data: dict[str, Any],
+        convert_lists: bool = False,
+    ) -> list[dict]:
+        """Splits JSON into a list of JSON chunks."""
+
+        if convert_lists:
+            chunks = self._json_split(
+                self._list_to_dict_preprocessing(json_data)
+            )
+        else:
+            chunks = self._json_split(json_data)
+
+        # Remove the last chunk if it's empty
+        if not chunks[-1]:
+            chunks.pop()
+        return chunks
+
+    def split_text(
+        self, json_data: dict[str, Any], convert_lists: bool = False
+    ) -> list[str]:
+        """Splits JSON into a list of JSON formatted strings."""
+
+        chunks = self.split_json(
+            json_data=json_data, convert_lists=convert_lists
+        )
+
+        # Convert to string
+        return [json.dumps(chunk) for chunk in chunks]
+
+    def create_documents(
+        self,
+        texts: list[dict],
+        convert_lists: bool = False,
+        metadatas: Optional[list[dict]] = None,
+    ) -> list[SplitterDocument]:
+        """Create documents from a list of json objects (dict)."""
+        _metadatas = metadatas or [{}] * len(texts)
+        documents = []
+        for i, text in enumerate(texts):
+            for chunk in self.split_text(
+                json_data=text, convert_lists=convert_lists
+            ):
+                metadata = copy.deepcopy(_metadatas[i])
+                new_doc = SplitterDocument(
+                    page_content=chunk, metadata=metadata
+                )
+                documents.append(new_doc)
+        return documents
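
For reference, a minimal usage sketch of the splitters added above. It is illustrative only and not part of the diff; the import path `shared.utils.splitter.text` is assumed from the file path in this changeset, and the chunk sizes are arbitrary.

from shared.utils.splitter.text import (
    Language,
    RecursiveCharacterTextSplitter,
    RecursiveJsonSplitter,
)

# Plain-text splitting: falls back through "\n\n", "\n", " ", "".
text_splitter = RecursiveCharacterTextSplitter(chunk_size=80, chunk_overlap=10)
print(text_splitter.split_text("First paragraph.\n\nSecond paragraph. " * 5))

# Code-aware splitting reuses the per-language separators defined above.
py_splitter = RecursiveCharacterTextSplitter.from_language(
    Language.PYTHON, chunk_size=60, chunk_overlap=0
)
print(py_splitter.split_text("def f():\n    return 1\n\n\ndef g():\n    return 2\n"))

# JSON splitting keeps nesting intact while capping serialized chunk size.
json_splitter = RecursiveJsonSplitter(max_chunk_size=120)
print(
    json_splitter.split_json(
        json_data={"a": {"b": 1, "c": 2}, "d": [3, 4]}, convert_lists=True
    )
)

Note that RecursiveJsonSplitter measures chunk size in characters of the serialized JSON (via `len(json.dumps(...))`), not in tokens.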