diff options
Diffstat (limited to 'R2R/r2r/base/utils/splitter')
-rwxr-xr-x | R2R/r2r/base/utils/splitter/__init__.py | 3 | ||||
-rwxr-xr-x | R2R/r2r/base/utils/splitter/text.py | 1979 |
2 files changed, 1982 insertions, 0 deletions
diff --git a/R2R/r2r/base/utils/splitter/__init__.py b/R2R/r2r/base/utils/splitter/__init__.py new file mode 100755 index 00000000..07a9f554 --- /dev/null +++ b/R2R/r2r/base/utils/splitter/__init__.py @@ -0,0 +1,3 @@ +from .text import RecursiveCharacterTextSplitter + +__all__ = ["RecursiveCharacterTextSplitter"] diff --git a/R2R/r2r/base/utils/splitter/text.py b/R2R/r2r/base/utils/splitter/text.py new file mode 100755 index 00000000..5458310c --- /dev/null +++ b/R2R/r2r/base/utils/splitter/text.py @@ -0,0 +1,1979 @@ +# Source - LangChain +# URL: https://github.com/langchain-ai/langchain/blob/6a5b084704afa22ca02f78d0464f35aed75d1ff2/libs/langchain/langchain/text_splitter.py#L851 +"""**Text Splitters** are classes for splitting text. + + +**Class hierarchy:** + +.. code-block:: + + BaseDocumentTransformer --> TextSplitter --> <name>TextSplitter # Example: CharacterTextSplitter + RecursiveCharacterTextSplitter --> <name>TextSplitter + +Note: **MarkdownHeaderTextSplitter** and **HTMLHeaderTextSplitter do not derive from TextSplitter. + + +**Main helpers:** + +.. code-block:: + + Document, Tokenizer, Language, LineType, HeaderType + +""" # noqa: E501 + +from __future__ import annotations + +import copy +import json +import logging +import pathlib +import re +from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import Enum +from io import BytesIO, StringIO +from typing import ( + AbstractSet, + Any, + Callable, + Collection, + Dict, + Iterable, + List, + Literal, + Optional, + Sequence, + Tuple, + Type, + TypedDict, + TypeVar, + Union, + cast, +) + +import requests +from pydantic import BaseModel, Field, PrivateAttr +from typing_extensions import NotRequired + +logger = logging.getLogger(__name__) + +TS = TypeVar("TS", bound="TextSplitter") + + +class BaseSerialized(TypedDict): + """Base class for serialized objects.""" + + lc: int + id: List[str] + name: NotRequired[str] + graph: NotRequired[Dict[str, Any]] + + +class SerializedConstructor(BaseSerialized): + """Serialized constructor.""" + + type: Literal["constructor"] + kwargs: Dict[str, Any] + + +class SerializedSecret(BaseSerialized): + """Serialized secret.""" + + type: Literal["secret"] + + +class SerializedNotImplemented(BaseSerialized): + """Serialized not implemented.""" + + type: Literal["not_implemented"] + repr: Optional[str] + + +def try_neq_default(value: Any, key: str, model: BaseModel) -> bool: + """Try to determine if a value is different from the default. + + Args: + value: The value. + key: The key. + model: The model. + + Returns: + Whether the value is different from the default. + """ + try: + return model.__fields__[key].get_default() != value + except Exception: + return True + + +class Serializable(BaseModel, ABC): + """Serializable base class.""" + + @classmethod + def is_lc_serializable(cls) -> bool: + """Is this class serializable?""" + return False + + @classmethod + def get_lc_namespace(cls) -> List[str]: + """Get the namespace of the langchain object. + + For example, if the class is `langchain.llms.openai.OpenAI`, then the + namespace is ["langchain", "llms", "openai"] + """ + return cls.__module__.split(".") + + @property + def lc_secrets(self) -> Dict[str, str]: + """A map of constructor argument names to secret ids. + + For example, + {"openai_api_key": "OPENAI_API_KEY"} + """ + return dict() + + @property + def lc_attributes(self) -> Dict: + """List of attribute names that should be included in the serialized kwargs. + + These attributes must be accepted by the constructor. + """ + return {} + + @classmethod + def lc_id(cls) -> List[str]: + """A unique identifier for this class for serialization purposes. + + The unique identifier is a list of strings that describes the path + to the object. + """ + return [*cls.get_lc_namespace(), cls.__name__] + + class Config: + extra = "ignore" + + def __repr_args__(self) -> Any: + return [ + (k, v) + for k, v in super().__repr_args__() + if (k not in self.__fields__ or try_neq_default(v, k, self)) + ] + + _lc_kwargs = PrivateAttr(default_factory=dict) + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._lc_kwargs = kwargs + + def to_json( + self, + ) -> Union[SerializedConstructor, SerializedNotImplemented]: + if not self.is_lc_serializable(): + return self.to_json_not_implemented() + + secrets = dict() + # Get latest values for kwargs if there is an attribute with same name + lc_kwargs = { + k: getattr(self, k, v) + for k, v in self._lc_kwargs.items() + if not (self.__exclude_fields__ or {}).get(k, False) # type: ignore + } + + # Merge the lc_secrets and lc_attributes from every class in the MRO + for cls in [None, *self.__class__.mro()]: + # Once we get to Serializable, we're done + if cls is Serializable: + break + + if cls: + deprecated_attributes = [ + "lc_namespace", + "lc_serializable", + ] + + for attr in deprecated_attributes: + if hasattr(cls, attr): + raise ValueError( + f"Class {self.__class__} has a deprecated " + f"attribute {attr}. Please use the corresponding " + f"classmethod instead." + ) + + # Get a reference to self bound to each class in the MRO + this = cast( + Serializable, self if cls is None else super(cls, self) + ) + + secrets.update(this.lc_secrets) + # Now also add the aliases for the secrets + # This ensures known secret aliases are hidden. + # Note: this does NOT hide any other extra kwargs + # that are not present in the fields. + for key in list(secrets): + value = secrets[key] + if key in this.__fields__: + secrets[this.__fields__[key].alias] = value + lc_kwargs.update(this.lc_attributes) + + # include all secrets, even if not specified in kwargs + # as these secrets may be passed as an environment variable instead + for key in secrets.keys(): + secret_value = getattr(self, key, None) or lc_kwargs.get(key) + if secret_value is not None: + lc_kwargs.update({key: secret_value}) + + return { + "lc": 1, + "type": "constructor", + "id": self.lc_id(), + "kwargs": ( + lc_kwargs + if not secrets + else _replace_secrets(lc_kwargs, secrets) + ), + } + + def to_json_not_implemented(self) -> SerializedNotImplemented: + return to_json_not_implemented(self) + + +def _replace_secrets( + root: Dict[Any, Any], secrets_map: Dict[str, str] +) -> Dict[Any, Any]: + result = root.copy() + for path, secret_id in secrets_map.items(): + [*parts, last] = path.split(".") + current = result + for part in parts: + if part not in current: + break + current[part] = current[part].copy() + current = current[part] + if last in current: + current[last] = { + "lc": 1, + "type": "secret", + "id": [secret_id], + } + return result + + +def to_json_not_implemented(obj: object) -> SerializedNotImplemented: + """Serialize a "not implemented" object. + + Args: + obj: object to serialize + + Returns: + SerializedNotImplemented + """ + _id: List[str] = [] + try: + if hasattr(obj, "__name__"): + _id = [*obj.__module__.split("."), obj.__name__] + elif hasattr(obj, "__class__"): + _id = [ + *obj.__class__.__module__.split("."), + obj.__class__.__name__, + ] + except Exception: + pass + + result: SerializedNotImplemented = { + "lc": 1, + "type": "not_implemented", + "id": _id, + "repr": None, + } + try: + result["repr"] = repr(obj) + except Exception: + pass + return result + + +class Document(Serializable): + """Class for storing a piece of text and associated metadata.""" + + page_content: str + """String text.""" + metadata: dict = Field(default_factory=dict) + """Arbitrary metadata about the page content (e.g., source, relationships to other + documents, etc.). + """ + type: Literal["Document"] = "Document" + + def __init__(self, page_content: str, **kwargs: Any) -> None: + """Pass page_content in as positional or named arg.""" + super().__init__(page_content=page_content, **kwargs) + + @classmethod + def is_lc_serializable(cls) -> bool: + """Return whether this class is serializable.""" + return True + + @classmethod + def get_lc_namespace(cls) -> List[str]: + """Get the namespace of the langchain object.""" + return ["langchain", "schema", "document"] + + +class BaseDocumentTransformer(ABC): + """Abstract base class for document transformation systems. + + A document transformation system takes a sequence of Documents and returns a + sequence of transformed Documents. + + Example: + .. code-block:: python + + class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel): + embeddings: Embeddings + similarity_fn: Callable = cosine_similarity + similarity_threshold: float = 0.95 + + class Config: + arbitrary_types_allowed = True + + def transform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + stateful_documents = get_stateful_documents(documents) + embedded_documents = _get_embeddings_from_stateful_docs( + self.embeddings, stateful_documents + ) + included_idxs = _filter_similar_embeddings( + embedded_documents, self.similarity_fn, self.similarity_threshold + ) + return [stateful_documents[i] for i in sorted(included_idxs)] + + async def atransform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + raise NotImplementedError + + """ # noqa: E501 + + @abstractmethod + def transform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + """Transform a list of documents. + + Args: + documents: A sequence of Documents to be transformed. + + Returns: + A list of transformed Documents. + """ + + async def atransform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + """Asynchronously transform a list of documents. + + Args: + documents: A sequence of Documents to be transformed. + + Returns: + A list of transformed Documents. + """ + raise NotImplementedError("This method is not implemented.") + # return await langchain_core.runnables.config.run_in_executor( + # None, self.transform_documents, documents, **kwargs + # ) + + +def _make_spacy_pipe_for_splitting( + pipe: str, *, max_length: int = 1_000_000 +) -> Any: # avoid importing spacy + try: + import spacy + except ImportError: + raise ImportError( + "Spacy is not installed, please install it with `pip install spacy`." + ) + if pipe == "sentencizer": + from spacy.lang.en import English + + sentencizer = English() + sentencizer.add_pipe("sentencizer") + else: + sentencizer = spacy.load(pipe, exclude=["ner", "tagger"]) + sentencizer.max_length = max_length + return sentencizer + + +def _split_text_with_regex( + text: str, separator: str, keep_separator: bool +) -> List[str]: + # Now that we have the separator, split the text + if separator: + if keep_separator: + # The parentheses in the pattern keep the delimiters in the result. + _splits = re.split(f"({separator})", text) + splits = [ + _splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2) + ] + if len(_splits) % 2 == 0: + splits += _splits[-1:] + splits = [_splits[0]] + splits + else: + splits = re.split(separator, text) + else: + splits = list(text) + return [s for s in splits if s != ""] + + +class TextSplitter(BaseDocumentTransformer, ABC): + """Interface for splitting text into chunks.""" + + def __init__( + self, + chunk_size: int = 4000, + chunk_overlap: int = 200, + length_function: Callable[[str], int] = len, + keep_separator: bool = False, + add_start_index: bool = False, + strip_whitespace: bool = True, + ) -> None: + """Create a new TextSplitter. + + Args: + chunk_size: Maximum size of chunks to return + chunk_overlap: Overlap in characters between chunks + length_function: Function that measures the length of given chunks + keep_separator: Whether to keep the separator in the chunks + add_start_index: If `True`, includes chunk's start index in metadata + strip_whitespace: If `True`, strips whitespace from the start and end of + every document + """ + if chunk_overlap > chunk_size: + raise ValueError( + f"Got a larger chunk overlap ({chunk_overlap}) than chunk size " + f"({chunk_size}), should be smaller." + ) + self._chunk_size = chunk_size + self._chunk_overlap = chunk_overlap + self._length_function = length_function + self._keep_separator = keep_separator + self._add_start_index = add_start_index + self._strip_whitespace = strip_whitespace + + @abstractmethod + def split_text(self, text: str) -> List[str]: + """Split text into multiple components.""" + + def create_documents( + self, texts: List[str], metadatas: Optional[List[dict]] = None + ) -> List[Document]: + """Create documents from a list of texts.""" + _metadatas = metadatas or [{}] * len(texts) + documents = [] + for i, text in enumerate(texts): + index = 0 + previous_chunk_len = 0 + for chunk in self.split_text(text): + metadata = copy.deepcopy(_metadatas[i]) + if self._add_start_index: + offset = index + previous_chunk_len - self._chunk_overlap + index = text.find(chunk, max(0, offset)) + metadata["start_index"] = index + previous_chunk_len = len(chunk) + new_doc = Document(page_content=chunk, metadata=metadata) + documents.append(new_doc) + return documents + + def split_documents(self, documents: Iterable[Document]) -> List[Document]: + """Split documents.""" + texts, metadatas = [], [] + for doc in documents: + texts.append(doc.page_content) + metadatas.append(doc.metadata) + return self.create_documents(texts, metadatas=metadatas) + + def _join_docs(self, docs: List[str], separator: str) -> Optional[str]: + text = separator.join(docs) + if self._strip_whitespace: + text = text.strip() + if text == "": + return None + else: + return text + + def _merge_splits( + self, splits: Iterable[str], separator: str + ) -> List[str]: + # We now want to combine these smaller pieces into medium size + # chunks to send to the LLM. + separator_len = self._length_function(separator) + + docs = [] + current_doc: List[str] = [] + total = 0 + for d in splits: + _len = self._length_function(d) + if ( + total + _len + (separator_len if len(current_doc) > 0 else 0) + > self._chunk_size + ): + if total > self._chunk_size: + logger.warning( + f"Created a chunk of size {total}, " + f"which is longer than the specified {self._chunk_size}" + ) + if len(current_doc) > 0: + doc = self._join_docs(current_doc, separator) + if doc is not None: + docs.append(doc) + # Keep on popping if: + # - we have a larger chunk than in the chunk overlap + # - or if we still have any chunks and the length is long + while total > self._chunk_overlap or ( + total + + _len + + (separator_len if len(current_doc) > 0 else 0) + > self._chunk_size + and total > 0 + ): + total -= self._length_function(current_doc[0]) + ( + separator_len if len(current_doc) > 1 else 0 + ) + current_doc = current_doc[1:] + current_doc.append(d) + total += _len + (separator_len if len(current_doc) > 1 else 0) + doc = self._join_docs(current_doc, separator) + if doc is not None: + docs.append(doc) + return docs + + @classmethod + def from_huggingface_tokenizer( + cls, tokenizer: Any, **kwargs: Any + ) -> TextSplitter: + """Text splitter that uses HuggingFace tokenizer to count length.""" + try: + from transformers import PreTrainedTokenizerBase + + if not isinstance(tokenizer, PreTrainedTokenizerBase): + raise ValueError( + "Tokenizer received was not an instance of PreTrainedTokenizerBase" + ) + + def _huggingface_tokenizer_length(text: str) -> int: + return len(tokenizer.encode(text)) + + except ImportError: + raise ValueError( + "Could not import transformers python package. " + "Please install it with `pip install transformers`." + ) + return cls(length_function=_huggingface_tokenizer_length, **kwargs) + + @classmethod + def from_tiktoken_encoder( + cls: Type[TS], + encoding_name: str = "gpt2", + model: Optional[str] = None, + allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), + disallowed_special: Union[Literal["all"], Collection[str]] = "all", + **kwargs: Any, + ) -> TS: + """Text splitter that uses tiktoken encoder to count length.""" + try: + import tiktoken + except ImportError: + raise ImportError( + "Could not import tiktoken python package. " + "This is needed in order to calculate max_tokens_for_prompt. " + "Please install it with `pip install tiktoken`." + ) + + if model is not None: + enc = tiktoken.encoding_for_model(model) + else: + enc = tiktoken.get_encoding(encoding_name) + + def _tiktoken_encoder(text: str) -> int: + return len( + enc.encode( + text, + allowed_special=allowed_special, + disallowed_special=disallowed_special, + ) + ) + + if issubclass(cls, TokenTextSplitter): + extra_kwargs = { + "encoding_name": encoding_name, + "model": model, + "allowed_special": allowed_special, + "disallowed_special": disallowed_special, + } + kwargs = {**kwargs, **extra_kwargs} + + return cls(length_function=_tiktoken_encoder, **kwargs) + + def transform_documents( + self, documents: Sequence[Document], **kwargs: Any + ) -> Sequence[Document]: + """Transform sequence of documents by splitting them.""" + return self.split_documents(list(documents)) + + +class CharacterTextSplitter(TextSplitter): + """Splitting text that looks at characters.""" + + def __init__( + self, + separator: str = "\n\n", + is_separator_regex: bool = False, + **kwargs: Any, + ) -> None: + """Create a new TextSplitter.""" + super().__init__(**kwargs) + self._separator = separator + self._is_separator_regex = is_separator_regex + + def split_text(self, text: str) -> List[str]: + """Split incoming text and return chunks.""" + # First we naively split the large input into a bunch of smaller ones. + separator = ( + self._separator + if self._is_separator_regex + else re.escape(self._separator) + ) + splits = _split_text_with_regex(text, separator, self._keep_separator) + _separator = "" if self._keep_separator else self._separator + return self._merge_splits(splits, _separator) + + +class LineType(TypedDict): + """Line type as typed dict.""" + + metadata: Dict[str, str] + content: str + + +class HeaderType(TypedDict): + """Header type as typed dict.""" + + level: int + name: str + data: str + + +class MarkdownHeaderTextSplitter: + """Splitting markdown files based on specified headers.""" + + def __init__( + self, + headers_to_split_on: List[Tuple[str, str]], + return_each_line: bool = False, + strip_headers: bool = True, + ): + """Create a new MarkdownHeaderTextSplitter. + + Args: + headers_to_split_on: Headers we want to track + return_each_line: Return each line w/ associated headers + strip_headers: Strip split headers from the content of the chunk + """ + # Output line-by-line or aggregated into chunks w/ common headers + self.return_each_line = return_each_line + # Given the headers we want to split on, + # (e.g., "#, ##, etc") order by length + self.headers_to_split_on = sorted( + headers_to_split_on, key=lambda split: len(split[0]), reverse=True + ) + # Strip headers split headers from the content of the chunk + self.strip_headers = strip_headers + + def aggregate_lines_to_chunks( + self, lines: List[LineType] + ) -> List[Document]: + """Combine lines with common metadata into chunks + Args: + lines: Line of text / associated header metadata + """ + aggregated_chunks: List[LineType] = [] + + for line in lines: + if ( + aggregated_chunks + and aggregated_chunks[-1]["metadata"] == line["metadata"] + ): + # If the last line in the aggregated list + # has the same metadata as the current line, + # append the current content to the last lines's content + aggregated_chunks[-1]["content"] += " \n" + line["content"] + elif ( + aggregated_chunks + and aggregated_chunks[-1]["metadata"] != line["metadata"] + # may be issues if other metadata is present + and len(aggregated_chunks[-1]["metadata"]) + < len(line["metadata"]) + and aggregated_chunks[-1]["content"].split("\n")[-1][0] == "#" + and not self.strip_headers + ): + # If the last line in the aggregated list + # has different metadata as the current line, + # and has shallower header level than the current line, + # and the last line is a header, + # and we are not stripping headers, + # append the current content to the last line's content + aggregated_chunks[-1]["content"] += " \n" + line["content"] + # and update the last line's metadata + aggregated_chunks[-1]["metadata"] = line["metadata"] + else: + # Otherwise, append the current line to the aggregated list + aggregated_chunks.append(line) + + return [ + Document(page_content=chunk["content"], metadata=chunk["metadata"]) + for chunk in aggregated_chunks + ] + + def split_text(self, text: str) -> List[Document]: + """Split markdown file + Args: + text: Markdown file""" + + # Split the input text by newline character ("\n"). + lines = text.split("\n") + # Final output + lines_with_metadata: List[LineType] = [] + # Content and metadata of the chunk currently being processed + current_content: List[str] = [] + current_metadata: Dict[str, str] = {} + # Keep track of the nested header structure + # header_stack: List[Dict[str, Union[int, str]]] = [] + header_stack: List[HeaderType] = [] + initial_metadata: Dict[str, str] = {} + + in_code_block = False + opening_fence = "" + + for line in lines: + stripped_line = line.strip() + + if not in_code_block: + # Exclude inline code spans + if ( + stripped_line.startswith("```") + and stripped_line.count("```") == 1 + ): + in_code_block = True + opening_fence = "```" + elif stripped_line.startswith("~~~"): + in_code_block = True + opening_fence = "~~~" + else: + if stripped_line.startswith(opening_fence): + in_code_block = False + opening_fence = "" + + if in_code_block: + current_content.append(stripped_line) + continue + + # Check each line against each of the header types (e.g., #, ##) + for sep, name in self.headers_to_split_on: + # Check if line starts with a header that we intend to split on + if stripped_line.startswith(sep) and ( + # Header with no text OR header is followed by space + # Both are valid conditions that sep is being used a header + len(stripped_line) == len(sep) + or stripped_line[len(sep)] == " " + ): + # Ensure we are tracking the header as metadata + if name is not None: + # Get the current header level + current_header_level = sep.count("#") + + # Pop out headers of lower or same level from the stack + while ( + header_stack + and header_stack[-1]["level"] + >= current_header_level + ): + # We have encountered a new header + # at the same or higher level + popped_header = header_stack.pop() + # Clear the metadata for the + # popped header in initial_metadata + if popped_header["name"] in initial_metadata: + initial_metadata.pop(popped_header["name"]) + + # Push the current header to the stack + header: HeaderType = { + "level": current_header_level, + "name": name, + "data": stripped_line[len(sep) :].strip(), + } + header_stack.append(header) + # Update initial_metadata with the current header + initial_metadata[name] = header["data"] + + # Add the previous line to the lines_with_metadata + # only if current_content is not empty + if current_content: + lines_with_metadata.append( + { + "content": "\n".join(current_content), + "metadata": current_metadata.copy(), + } + ) + current_content.clear() + + if not self.strip_headers: + current_content.append(stripped_line) + + break + else: + if stripped_line: + current_content.append(stripped_line) + elif current_content: + lines_with_metadata.append( + { + "content": "\n".join(current_content), + "metadata": current_metadata.copy(), + } + ) + current_content.clear() + + current_metadata = initial_metadata.copy() + + if current_content: + lines_with_metadata.append( + { + "content": "\n".join(current_content), + "metadata": current_metadata, + } + ) + + # lines_with_metadata has each line with associated header metadata + # aggregate these into chunks based on common metadata + if not self.return_each_line: + return self.aggregate_lines_to_chunks(lines_with_metadata) + else: + return [ + Document( + page_content=chunk["content"], metadata=chunk["metadata"] + ) + for chunk in lines_with_metadata + ] + + +class ElementType(TypedDict): + """Element type as typed dict.""" + + url: str + xpath: str + content: str + metadata: Dict[str, str] + + +class HTMLHeaderTextSplitter: + """ + Splitting HTML files based on specified headers. + Requires lxml package. + """ + + def __init__( + self, + headers_to_split_on: List[Tuple[str, str]], + return_each_element: bool = False, + ): + """Create a new HTMLHeaderTextSplitter. + + Args: + headers_to_split_on: list of tuples of headers we want to track mapped to + (arbitrary) keys for metadata. Allowed header values: h1, h2, h3, h4, + h5, h6 e.g. [("h1", "Header 1"), ("h2", "Header 2)]. + return_each_element: Return each element w/ associated headers. + """ + # Output element-by-element or aggregated into chunks w/ common headers + self.return_each_element = return_each_element + self.headers_to_split_on = sorted(headers_to_split_on) + + def aggregate_elements_to_chunks( + self, elements: List[ElementType] + ) -> List[Document]: + """Combine elements with common metadata into chunks + + Args: + elements: HTML element content with associated identifying info and metadata + """ + aggregated_chunks: List[ElementType] = [] + + for element in elements: + if ( + aggregated_chunks + and aggregated_chunks[-1]["metadata"] == element["metadata"] + ): + # If the last element in the aggregated list + # has the same metadata as the current element, + # append the current content to the last element's content + aggregated_chunks[-1]["content"] += " \n" + element["content"] + else: + # Otherwise, append the current element to the aggregated list + aggregated_chunks.append(element) + + return [ + Document(page_content=chunk["content"], metadata=chunk["metadata"]) + for chunk in aggregated_chunks + ] + + def split_text_from_url(self, url: str) -> List[Document]: + """Split HTML from web URL + + Args: + url: web URL + """ + r = requests.get(url) + return self.split_text_from_file(BytesIO(r.content)) + + def split_text(self, text: str) -> List[Document]: + """Split HTML text string + + Args: + text: HTML text + """ + return self.split_text_from_file(StringIO(text)) + + def split_text_from_file(self, file: Any) -> List[Document]: + """Split HTML file + + Args: + file: HTML file + """ + try: + from lxml import etree + except ImportError as e: + raise ImportError( + "Unable to import lxml, please install with `pip install lxml`." + ) from e + # use lxml library to parse html document and return xml ElementTree + # Explicitly encoding in utf-8 allows non-English + # html files to be processed without garbled characters + parser = etree.HTMLParser(encoding="utf-8") + tree = etree.parse(file, parser) + + # document transformation for "structure-aware" chunking is handled with xsl. + # see comments in html_chunks_with_headers.xslt for more detailed information. + xslt_path = ( + pathlib.Path(__file__).parent + / "document_transformers/xsl/html_chunks_with_headers.xslt" + ) + xslt_tree = etree.parse(xslt_path) + transform = etree.XSLT(xslt_tree) + result = transform(tree) + result_dom = etree.fromstring(str(result)) + + # create filter and mapping for header metadata + header_filter = [header[0] for header in self.headers_to_split_on] + header_mapping = dict(self.headers_to_split_on) + + # map xhtml namespace prefix + ns_map = {"h": "http://www.w3.org/1999/xhtml"} + + # build list of elements from DOM + elements = [] + for element in result_dom.findall("*//*", ns_map): + if element.findall("*[@class='headers']") or element.findall( + "*[@class='chunk']" + ): + elements.append( + ElementType( + url=file, + xpath="".join( + [ + node.text + for node in element.findall( + "*[@class='xpath']", ns_map + ) + ] + ), + content="".join( + [ + node.text + for node in element.findall( + "*[@class='chunk']", ns_map + ) + ] + ), + metadata={ + # Add text of specified headers to metadata using header + # mapping. + header_mapping[node.tag]: node.text + for node in filter( + lambda x: x.tag in header_filter, + element.findall( + "*[@class='headers']/*", ns_map + ), + ) + }, + ) + ) + + if not self.return_each_element: + return self.aggregate_elements_to_chunks(elements) + else: + return [ + Document( + page_content=chunk["content"], metadata=chunk["metadata"] + ) + for chunk in elements + ] + + +# should be in newer Python versions (3.10+) +# @dataclass(frozen=True, kw_only=True, slots=True) +@dataclass(frozen=True) +class Tokenizer: + """Tokenizer data class.""" + + chunk_overlap: int + """Overlap in tokens between chunks""" + tokens_per_chunk: int + """Maximum number of tokens per chunk""" + decode: Callable[[List[int]], str] + """ Function to decode a list of token ids to a string""" + encode: Callable[[str], List[int]] + """ Function to encode a string to a list of token ids""" + + +def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> List[str]: + """Split incoming text and return chunks using tokenizer.""" + splits: List[str] = [] + input_ids = tokenizer.encode(text) + start_idx = 0 + cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids)) + chunk_ids = input_ids[start_idx:cur_idx] + while start_idx < len(input_ids): + splits.append(tokenizer.decode(chunk_ids)) + if cur_idx == len(input_ids): + break + start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap + cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids)) + chunk_ids = input_ids[start_idx:cur_idx] + return splits + + +class TokenTextSplitter(TextSplitter): + """Splitting text to tokens using model tokenizer.""" + + def __init__( + self, + encoding_name: str = "gpt2", + model: Optional[str] = None, + allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), + disallowed_special: Union[Literal["all"], Collection[str]] = "all", + **kwargs: Any, + ) -> None: + """Create a new TextSplitter.""" + super().__init__(**kwargs) + try: + import tiktoken + except ImportError: + raise ImportError( + "Could not import tiktoken python package. " + "This is needed in order to for TokenTextSplitter. " + "Please install it with `pip install tiktoken`." + ) + + if model is not None: + enc = tiktoken.encoding_for_model(model) + else: + enc = tiktoken.get_encoding(encoding_name) + self._tokenizer = enc + self._allowed_special = allowed_special + self._disallowed_special = disallowed_special + + def split_text(self, text: str) -> List[str]: + def _encode(_text: str) -> List[int]: + return self._tokenizer.encode( + _text, + allowed_special=self._allowed_special, + disallowed_special=self._disallowed_special, + ) + + tokenizer = Tokenizer( + chunk_overlap=self._chunk_overlap, + tokens_per_chunk=self._chunk_size, + decode=self._tokenizer.decode, + encode=_encode, + ) + + return split_text_on_tokens(text=text, tokenizer=tokenizer) + + +class SentenceTransformersTokenTextSplitter(TextSplitter): + """Splitting text to tokens using sentence model tokenizer.""" + + def __init__( + self, + chunk_overlap: int = 50, + model: str = "sentence-transformers/all-mpnet-base-v2", + tokens_per_chunk: Optional[int] = None, + **kwargs: Any, + ) -> None: + """Create a new TextSplitter.""" + super().__init__(**kwargs, chunk_overlap=chunk_overlap) + + try: + from sentence_transformers import SentenceTransformer + except ImportError: + raise ImportError( + "Could not import sentence_transformer python package. " + "This is needed in order to for SentenceTransformersTokenTextSplitter. " + "Please install it with `pip install sentence-transformers`." + ) + + self.model = model + self._model = SentenceTransformer(self.model, trust_remote_code=True) + self.tokenizer = self._model.tokenizer + self._initialize_chunk_configuration(tokens_per_chunk=tokens_per_chunk) + + def _initialize_chunk_configuration( + self, *, tokens_per_chunk: Optional[int] + ) -> None: + self.maximum_tokens_per_chunk = cast(int, self._model.max_seq_length) + + if tokens_per_chunk is None: + self.tokens_per_chunk = self.maximum_tokens_per_chunk + else: + self.tokens_per_chunk = tokens_per_chunk + + if self.tokens_per_chunk > self.maximum_tokens_per_chunk: + raise ValueError( + f"The token limit of the models '{self.model}'" + f" is: {self.maximum_tokens_per_chunk}." + f" Argument tokens_per_chunk={self.tokens_per_chunk}" + f" > maximum token limit." + ) + + def split_text(self, text: str) -> List[str]: + def encode_strip_start_and_stop_token_ids(text: str) -> List[int]: + return self._encode(text)[1:-1] + + tokenizer = Tokenizer( + chunk_overlap=self._chunk_overlap, + tokens_per_chunk=self.tokens_per_chunk, + decode=self.tokenizer.decode, + encode=encode_strip_start_and_stop_token_ids, + ) + + return split_text_on_tokens(text=text, tokenizer=tokenizer) + + def count_tokens(self, *, text: str) -> int: + return len(self._encode(text)) + + _max_length_equal_32_bit_integer: int = 2**32 + + def _encode(self, text: str) -> List[int]: + token_ids_with_start_and_end_token_ids = self.tokenizer.encode( + text, + max_length=self._max_length_equal_32_bit_integer, + truncation="do_not_truncate", + ) + return token_ids_with_start_and_end_token_ids + + +class Language(str, Enum): + """Enum of the programming languages.""" + + CPP = "cpp" + GO = "go" + JAVA = "java" + KOTLIN = "kotlin" + JS = "js" + TS = "ts" + PHP = "php" + PROTO = "proto" + PYTHON = "python" + RST = "rst" + RUBY = "ruby" + RUST = "rust" + SCALA = "scala" + SWIFT = "swift" + MARKDOWN = "markdown" + LATEX = "latex" + HTML = "html" + SOL = "sol" + CSHARP = "csharp" + COBOL = "cobol" + C = "c" + LUA = "lua" + PERL = "perl" + + +class RecursiveCharacterTextSplitter(TextSplitter): + """Splitting text by recursively look at characters. + + Recursively tries to split by different characters to find one + that works. + """ + + def __init__( + self, + separators: Optional[List[str]] = None, + keep_separator: bool = True, + is_separator_regex: bool = False, + **kwargs: Any, + ) -> None: + """Create a new TextSplitter.""" + super().__init__(keep_separator=keep_separator, **kwargs) + self._separators = separators or ["\n\n", "\n", " ", ""] + self._is_separator_regex = is_separator_regex + + def _split_text(self, text: str, separators: List[str]) -> List[str]: + """Split incoming text and return chunks.""" + final_chunks = [] + # Get appropriate separator to use + separator = separators[-1] + new_separators = [] + for i, _s in enumerate(separators): + _separator = _s if self._is_separator_regex else re.escape(_s) + if _s == "": + separator = _s + break + if re.search(_separator, text): + separator = _s + new_separators = separators[i + 1 :] + break + + _separator = ( + separator if self._is_separator_regex else re.escape(separator) + ) + splits = _split_text_with_regex(text, _separator, self._keep_separator) + + # Now go merging things, recursively splitting longer texts. + _good_splits = [] + _separator = "" if self._keep_separator else separator + for s in splits: + if self._length_function(s) < self._chunk_size: + _good_splits.append(s) + else: + if _good_splits: + merged_text = self._merge_splits(_good_splits, _separator) + final_chunks.extend(merged_text) + _good_splits = [] + if not new_separators: + final_chunks.append(s) + else: + other_info = self._split_text(s, new_separators) + final_chunks.extend(other_info) + if _good_splits: + merged_text = self._merge_splits(_good_splits, _separator) + final_chunks.extend(merged_text) + return final_chunks + + def split_text(self, text: str) -> List[str]: + return self._split_text(text, self._separators) + + @classmethod + def from_language( + cls, language: Language, **kwargs: Any + ) -> RecursiveCharacterTextSplitter: + separators = cls.get_separators_for_language(language) + return cls(separators=separators, is_separator_regex=True, **kwargs) + + @staticmethod + def get_separators_for_language(language: Language) -> List[str]: + if language == Language.CPP: + return [ + # Split along class definitions + "\nclass ", + # Split along function definitions + "\nvoid ", + "\nint ", + "\nfloat ", + "\ndouble ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\nswitch ", + "\ncase ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.GO: + return [ + # Split along function definitions + "\nfunc ", + "\nvar ", + "\nconst ", + "\ntype ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nswitch ", + "\ncase ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.JAVA: + return [ + # Split along class definitions + "\nclass ", + # Split along method definitions + "\npublic ", + "\nprotected ", + "\nprivate ", + "\nstatic ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\nswitch ", + "\ncase ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.KOTLIN: + return [ + # Split along class definitions + "\nclass ", + # Split along method definitions + "\npublic ", + "\nprotected ", + "\nprivate ", + "\ninternal ", + "\ncompanion ", + "\nfun ", + "\nval ", + "\nvar ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\nwhen ", + "\ncase ", + "\nelse ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.JS: + return [ + # Split along function definitions + "\nfunction ", + "\nconst ", + "\nlet ", + "\nvar ", + "\nclass ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\nswitch ", + "\ncase ", + "\ndefault ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.TS: + return [ + "\nenum ", + "\ninterface ", + "\nnamespace ", + "\ntype ", + # Split along class definitions + "\nclass ", + # Split along function definitions + "\nfunction ", + "\nconst ", + "\nlet ", + "\nvar ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\nswitch ", + "\ncase ", + "\ndefault ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.PHP: + return [ + # Split along function definitions + "\nfunction ", + # Split along class definitions + "\nclass ", + # Split along control flow statements + "\nif ", + "\nforeach ", + "\nwhile ", + "\ndo ", + "\nswitch ", + "\ncase ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.PROTO: + return [ + # Split along message definitions + "\nmessage ", + # Split along service definitions + "\nservice ", + # Split along enum definitions + "\nenum ", + # Split along option definitions + "\noption ", + # Split along import statements + "\nimport ", + # Split along syntax declarations + "\nsyntax ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.PYTHON: + return [ + # First, try to split along class definitions + "\nclass ", + "\ndef ", + "\n\tdef ", + # Now split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.RST: + return [ + # Split along section titles + "\n=+\n", + "\n-+\n", + "\n\\*+\n", + # Split along directive markers + "\n\n.. *\n\n", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.RUBY: + return [ + # Split along method definitions + "\ndef ", + "\nclass ", + # Split along control flow statements + "\nif ", + "\nunless ", + "\nwhile ", + "\nfor ", + "\ndo ", + "\nbegin ", + "\nrescue ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.RUST: + return [ + # Split along function definitions + "\nfn ", + "\nconst ", + "\nlet ", + # Split along control flow statements + "\nif ", + "\nwhile ", + "\nfor ", + "\nloop ", + "\nmatch ", + "\nconst ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.SCALA: + return [ + # Split along class definitions + "\nclass ", + "\nobject ", + # Split along method definitions + "\ndef ", + "\nval ", + "\nvar ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\nmatch ", + "\ncase ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.SWIFT: + return [ + # Split along function definitions + "\nfunc ", + # Split along class definitions + "\nclass ", + "\nstruct ", + "\nenum ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\ndo ", + "\nswitch ", + "\ncase ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.MARKDOWN: + return [ + # First, try to split along Markdown headings (starting with level 2) + "\n#{1,6} ", + # Note the alternative syntax for headings (below) is not handled here + # Heading level 2 + # --------------- + # End of code block + "```\n", + # Horizontal lines + "\n\\*\\*\\*+\n", + "\n---+\n", + "\n___+\n", + # Note that this splitter doesn't handle horizontal lines defined + # by *three or more* of ***, ---, or ___, but this is not handled + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.LATEX: + return [ + # First, try to split along Latex sections + "\n\\\\chapter{", + "\n\\\\section{", + "\n\\\\subsection{", + "\n\\\\subsubsection{", + # Now split by environments + "\n\\\\begin{enumerate}", + "\n\\\\begin{itemize}", + "\n\\\\begin{description}", + "\n\\\\begin{list}", + "\n\\\\begin{quote}", + "\n\\\\begin{quotation}", + "\n\\\\begin{verse}", + "\n\\\\begin{verbatim}", + # Now split by math environments + "\n\\\begin{align}", + "$$", + "$", + # Now split by the normal type of lines + " ", + "", + ] + elif language == Language.HTML: + return [ + # First, try to split along HTML tags + "<body", + "<div", + "<p", + "<br", + "<li", + "<h1", + "<h2", + "<h3", + "<h4", + "<h5", + "<h6", + "<span", + "<table", + "<tr", + "<td", + "<th", + "<ul", + "<ol", + "<header", + "<footer", + "<nav", + # Head + "<head", + "<style", + "<script", + "<meta", + "<title", + "", + ] + elif language == Language.CSHARP: + return [ + "\ninterface ", + "\nenum ", + "\nimplements ", + "\ndelegate ", + "\nevent ", + # Split along class definitions + "\nclass ", + "\nabstract ", + # Split along method definitions + "\npublic ", + "\nprotected ", + "\nprivate ", + "\nstatic ", + "\nreturn ", + # Split along control flow statements + "\nif ", + "\ncontinue ", + "\nfor ", + "\nforeach ", + "\nwhile ", + "\nswitch ", + "\nbreak ", + "\ncase ", + "\nelse ", + # Split by exceptions + "\ntry ", + "\nthrow ", + "\nfinally ", + "\ncatch ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.SOL: + return [ + # Split along compiler information definitions + "\npragma ", + "\nusing ", + # Split along contract definitions + "\ncontract ", + "\ninterface ", + "\nlibrary ", + # Split along method definitions + "\nconstructor ", + "\ntype ", + "\nfunction ", + "\nevent ", + "\nmodifier ", + "\nerror ", + "\nstruct ", + "\nenum ", + # Split along control flow statements + "\nif ", + "\nfor ", + "\nwhile ", + "\ndo while ", + "\nassembly ", + # Split by the normal type of lines + "\n\n", + "\n", + " ", + "", + ] + elif language == Language.COBOL: + return [ + # Split along divisions + "\nIDENTIFICATION DIVISION.", + "\nENVIRONMENT DIVISION.", + "\nDATA DIVISION.", + "\nPROCEDURE DIVISION.", + # Split along sections within DATA DIVISION + "\nWORKING-STORAGE SECTION.", + "\nLINKAGE SECTION.", + "\nFILE SECTION.", + # Split along sections within PROCEDURE DIVISION + "\nINPUT-OUTPUT SECTION.", + # Split along paragraphs and common statements + "\nOPEN ", + "\nCLOSE ", + "\nREAD ", + "\nWRITE ", + "\nIF ", + "\nELSE ", + "\nMOVE ", + "\nPERFORM ", + "\nUNTIL ", + "\nVARYING ", + "\nACCEPT ", + "\nDISPLAY ", + "\nSTOP RUN.", + # Split by the normal type of lines + "\n", + " ", + "", + ] + + else: + raise ValueError( + f"Language {language} is not supported! " + f"Please choose from {list(Language)}" + ) + + +class NLTKTextSplitter(TextSplitter): + """Splitting text using NLTK package.""" + + def __init__( + self, separator: str = "\n\n", language: str = "english", **kwargs: Any + ) -> None: + """Initialize the NLTK splitter.""" + super().__init__(**kwargs) + try: + from nltk.tokenize import sent_tokenize + + self._tokenizer = sent_tokenize + except ImportError: + raise ImportError( + "NLTK is not installed, please install it with `pip install nltk`." + ) + self._separator = separator + self._language = language + + def split_text(self, text: str) -> List[str]: + """Split incoming text and return chunks.""" + # First we naively split the large input into a bunch of smaller ones. + splits = self._tokenizer(text, language=self._language) + return self._merge_splits(splits, self._separator) + + +class SpacyTextSplitter(TextSplitter): + """Splitting text using Spacy package. + + + Per default, Spacy's `en_core_web_sm` model is used and + its default max_length is 1000000 (it is the length of maximum character + this model takes which can be increased for large files). For a faster, but + potentially less accurate splitting, you can use `pipe='sentencizer'`. + """ + + def __init__( + self, + separator: str = "\n\n", + pipe: str = "en_core_web_sm", + max_length: int = 1_000_000, + **kwargs: Any, + ) -> None: + """Initialize the spacy text splitter.""" + super().__init__(**kwargs) + self._tokenizer = _make_spacy_pipe_for_splitting( + pipe, max_length=max_length + ) + self._separator = separator + + def split_text(self, text: str) -> List[str]: + """Split incoming text and return chunks.""" + splits = (s.text for s in self._tokenizer(text).sents) + return self._merge_splits(splits, self._separator) + + +class KonlpyTextSplitter(TextSplitter): + """Splitting text using Konlpy package. + + It is good for splitting Korean text. + """ + + def __init__( + self, + separator: str = "\n\n", + **kwargs: Any, + ) -> None: + """Initialize the Konlpy text splitter.""" + super().__init__(**kwargs) + self._separator = separator + try: + from konlpy.tag import Kkma + except ImportError: + raise ImportError( + """ + Konlpy is not installed, please install it with + `pip install konlpy` + """ + ) + self.kkma = Kkma() + + def split_text(self, text: str) -> List[str]: + """Split incoming text and return chunks.""" + splits = self.kkma.sentences(text) + return self._merge_splits(splits, self._separator) + + +# For backwards compatibility +class PythonCodeTextSplitter(RecursiveCharacterTextSplitter): + """Attempts to split the text along Python syntax.""" + + def __init__(self, **kwargs: Any) -> None: + """Initialize a PythonCodeTextSplitter.""" + separators = self.get_separators_for_language(Language.PYTHON) + super().__init__(separators=separators, **kwargs) + + +class MarkdownTextSplitter(RecursiveCharacterTextSplitter): + """Attempts to split the text along Markdown-formatted headings.""" + + def __init__(self, **kwargs: Any) -> None: + """Initialize a MarkdownTextSplitter.""" + separators = self.get_separators_for_language(Language.MARKDOWN) + super().__init__(separators=separators, **kwargs) + + +class LatexTextSplitter(RecursiveCharacterTextSplitter): + """Attempts to split the text along Latex-formatted layout elements.""" + + def __init__(self, **kwargs: Any) -> None: + """Initialize a LatexTextSplitter.""" + separators = self.get_separators_for_language(Language.LATEX) + super().__init__(separators=separators, **kwargs) + + +class RecursiveJsonSplitter: + def __init__( + self, max_chunk_size: int = 2000, min_chunk_size: Optional[int] = None + ): + super().__init__() + self.max_chunk_size = max_chunk_size + self.min_chunk_size = ( + min_chunk_size + if min_chunk_size is not None + else max(max_chunk_size - 200, 50) + ) + + @staticmethod + def _json_size(data: Dict) -> int: + """Calculate the size of the serialized JSON object.""" + return len(json.dumps(data)) + + @staticmethod + def _set_nested_dict(d: Dict, path: List[str], value: Any) -> None: + """Set a value in a nested dictionary based on the given path.""" + for key in path[:-1]: + d = d.setdefault(key, {}) + d[path[-1]] = value + + def _list_to_dict_preprocessing(self, data: Any) -> Any: + if isinstance(data, dict): + # Process each key-value pair in the dictionary + return { + k: self._list_to_dict_preprocessing(v) for k, v in data.items() + } + elif isinstance(data, list): + # Convert the list to a dictionary with index-based keys + return { + str(i): self._list_to_dict_preprocessing(item) + for i, item in enumerate(data) + } + else: + # Base case: the item is neither a dict nor a list, so return it unchanged + return data + + def _json_split( + self, + data: Dict[str, Any], + current_path: List[str] = [], + chunks: List[Dict] = [{}], + ) -> List[Dict]: + """ + Split json into maximum size dictionaries while preserving structure. + """ + if isinstance(data, dict): + for key, value in data.items(): + new_path = current_path + [key] + chunk_size = self._json_size(chunks[-1]) + size = self._json_size({key: value}) + remaining = self.max_chunk_size - chunk_size + + if size < remaining: + # Add item to current chunk + self._set_nested_dict(chunks[-1], new_path, value) + else: + if chunk_size >= self.min_chunk_size: + # Chunk is big enough, start a new chunk + chunks.append({}) + + # Iterate + self._json_split(value, new_path, chunks) + else: + # handle single item + self._set_nested_dict(chunks[-1], current_path, data) + return chunks + + def split_json( + self, + json_data: Dict[str, Any], + convert_lists: bool = False, + ) -> List[Dict]: + """Splits JSON into a list of JSON chunks""" + + if convert_lists: + chunks = self._json_split( + self._list_to_dict_preprocessing(json_data) + ) + else: + chunks = self._json_split(json_data) + + # Remove the last chunk if it's empty + if not chunks[-1]: + chunks.pop() + return chunks + + def split_text( + self, json_data: Dict[str, Any], convert_lists: bool = False + ) -> List[str]: + """Splits JSON into a list of JSON formatted strings""" + + chunks = self.split_json( + json_data=json_data, convert_lists=convert_lists + ) + + # Convert to string + return [json.dumps(chunk) for chunk in chunks] + + def create_documents( + self, + texts: List[Dict], + convert_lists: bool = False, + metadatas: Optional[List[dict]] = None, + ) -> List[Document]: + """Create documents from a list of json objects (Dict).""" + _metadatas = metadatas or [{}] * len(texts) + documents = [] + for i, text in enumerate(texts): + for chunk in self.split_text( + json_data=text, convert_lists=convert_lists + ): + metadata = copy.deepcopy(_metadatas[i]) + new_doc = Document(page_content=chunk, metadata=metadata) + documents.append(new_doc) + return documents |