Diffstat (limited to 'R2R/r2r/base/utils/splitter')
-rwxr-xr-x  R2R/r2r/base/utils/splitter/__init__.py |    3
-rwxr-xr-x  R2R/r2r/base/utils/splitter/text.py     | 1979
2 files changed, 1982 insertions, 0 deletions
diff --git a/R2R/r2r/base/utils/splitter/__init__.py b/R2R/r2r/base/utils/splitter/__init__.py
new file mode 100755
index 00000000..07a9f554
--- /dev/null
+++ b/R2R/r2r/base/utils/splitter/__init__.py
@@ -0,0 +1,3 @@
+from .text import RecursiveCharacterTextSplitter
+
+__all__ = ["RecursiveCharacterTextSplitter"]
diff --git a/R2R/r2r/base/utils/splitter/text.py b/R2R/r2r/base/utils/splitter/text.py
new file mode 100755
index 00000000..5458310c
--- /dev/null
+++ b/R2R/r2r/base/utils/splitter/text.py
@@ -0,0 +1,1979 @@
+# Source - LangChain
+# URL: https://github.com/langchain-ai/langchain/blob/6a5b084704afa22ca02f78d0464f35aed75d1ff2/libs/langchain/langchain/text_splitter.py#L851
+"""**Text Splitters** are classes for splitting text.
+
+
+**Class hierarchy:**
+
+.. code-block::
+
+ BaseDocumentTransformer --> TextSplitter --> <name>TextSplitter # Example: CharacterTextSplitter
+ RecursiveCharacterTextSplitter --> <name>TextSplitter
+
+Note: **MarkdownHeaderTextSplitter** and **HTMLHeaderTextSplitter** do not derive from TextSplitter.
+
+
+**Main helpers:**
+
+.. code-block::
+
+ Document, Tokenizer, Language, LineType, HeaderType
+
+""" # noqa: E501
+
+from __future__ import annotations
+
+import copy
+import json
+import logging
+import pathlib
+import re
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from enum import Enum
+from io import BytesIO, StringIO
+from typing import (
+ AbstractSet,
+ Any,
+ Callable,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Literal,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ TypedDict,
+ TypeVar,
+ Union,
+ cast,
+)
+
+import requests
+from pydantic import BaseModel, Field, PrivateAttr
+from typing_extensions import NotRequired
+
+logger = logging.getLogger(__name__)
+
+TS = TypeVar("TS", bound="TextSplitter")
+
+
+class BaseSerialized(TypedDict):
+ """Base class for serialized objects."""
+
+ lc: int
+ id: List[str]
+ name: NotRequired[str]
+ graph: NotRequired[Dict[str, Any]]
+
+
+class SerializedConstructor(BaseSerialized):
+ """Serialized constructor."""
+
+ type: Literal["constructor"]
+ kwargs: Dict[str, Any]
+
+
+class SerializedSecret(BaseSerialized):
+ """Serialized secret."""
+
+ type: Literal["secret"]
+
+
+class SerializedNotImplemented(BaseSerialized):
+ """Serialized not implemented."""
+
+ type: Literal["not_implemented"]
+ repr: Optional[str]
+
+
+def try_neq_default(value: Any, key: str, model: BaseModel) -> bool:
+ """Try to determine if a value is different from the default.
+
+ Args:
+ value: The value.
+ key: The key.
+ model: The model.
+
+ Returns:
+ Whether the value is different from the default.
+ """
+ try:
+ return model.__fields__[key].get_default() != value
+ except Exception:
+ return True
+
+
+class Serializable(BaseModel, ABC):
+ """Serializable base class."""
+
+ @classmethod
+ def is_lc_serializable(cls) -> bool:
+ """Is this class serializable?"""
+ return False
+
+ @classmethod
+ def get_lc_namespace(cls) -> List[str]:
+ """Get the namespace of the langchain object.
+
+ For example, if the class is `langchain.llms.openai.OpenAI`, then the
+ namespace is ["langchain", "llms", "openai"]
+ """
+ return cls.__module__.split(".")
+
+ @property
+ def lc_secrets(self) -> Dict[str, str]:
+ """A map of constructor argument names to secret ids.
+
+ For example,
+ {"openai_api_key": "OPENAI_API_KEY"}
+ """
+ return dict()
+
+ @property
+ def lc_attributes(self) -> Dict:
+ """List of attribute names that should be included in the serialized kwargs.
+
+ These attributes must be accepted by the constructor.
+ """
+ return {}
+
+ @classmethod
+ def lc_id(cls) -> List[str]:
+ """A unique identifier for this class for serialization purposes.
+
+ The unique identifier is a list of strings that describes the path
+ to the object.
+ """
+ return [*cls.get_lc_namespace(), cls.__name__]
+
+ class Config:
+ extra = "ignore"
+
+ def __repr_args__(self) -> Any:
+ return [
+ (k, v)
+ for k, v in super().__repr_args__()
+ if (k not in self.__fields__ or try_neq_default(v, k, self))
+ ]
+
+ _lc_kwargs = PrivateAttr(default_factory=dict)
+
+ def __init__(self, **kwargs: Any) -> None:
+ super().__init__(**kwargs)
+ self._lc_kwargs = kwargs
+
+ def to_json(
+ self,
+ ) -> Union[SerializedConstructor, SerializedNotImplemented]:
+ if not self.is_lc_serializable():
+ return self.to_json_not_implemented()
+
+ secrets = dict()
+ # Get latest values for kwargs if there is an attribute with the same name
+ lc_kwargs = {
+ k: getattr(self, k, v)
+ for k, v in self._lc_kwargs.items()
+ if not (self.__exclude_fields__ or {}).get(k, False) # type: ignore
+ }
+
+ # Merge the lc_secrets and lc_attributes from every class in the MRO
+ for cls in [None, *self.__class__.mro()]:
+ # Once we get to Serializable, we're done
+ if cls is Serializable:
+ break
+
+ if cls:
+ deprecated_attributes = [
+ "lc_namespace",
+ "lc_serializable",
+ ]
+
+ for attr in deprecated_attributes:
+ if hasattr(cls, attr):
+ raise ValueError(
+ f"Class {self.__class__} has a deprecated "
+ f"attribute {attr}. Please use the corresponding "
+ f"classmethod instead."
+ )
+
+ # Get a reference to self bound to each class in the MRO
+ this = cast(
+ Serializable, self if cls is None else super(cls, self)
+ )
+
+ secrets.update(this.lc_secrets)
+ # Now also add the aliases for the secrets
+ # This ensures known secret aliases are hidden.
+ # Note: this does NOT hide any other extra kwargs
+ # that are not present in the fields.
+ for key in list(secrets):
+ value = secrets[key]
+ if key in this.__fields__:
+ secrets[this.__fields__[key].alias] = value
+ lc_kwargs.update(this.lc_attributes)
+
+ # include all secrets, even if not specified in kwargs
+ # as these secrets may be passed as an environment variable instead
+ for key in secrets.keys():
+ secret_value = getattr(self, key, None) or lc_kwargs.get(key)
+ if secret_value is not None:
+ lc_kwargs.update({key: secret_value})
+
+ return {
+ "lc": 1,
+ "type": "constructor",
+ "id": self.lc_id(),
+ "kwargs": (
+ lc_kwargs
+ if not secrets
+ else _replace_secrets(lc_kwargs, secrets)
+ ),
+ }
+
+ def to_json_not_implemented(self) -> SerializedNotImplemented:
+ return to_json_not_implemented(self)
+
+
+def _replace_secrets(
+ root: Dict[Any, Any], secrets_map: Dict[str, str]
+) -> Dict[Any, Any]:
+ result = root.copy()
+ for path, secret_id in secrets_map.items():
+ [*parts, last] = path.split(".")
+ current = result
+ for part in parts:
+ if part not in current:
+ break
+ current[part] = current[part].copy()
+ current = current[part]
+ if last in current:
+ current[last] = {
+ "lc": 1,
+ "type": "secret",
+ "id": [secret_id],
+ }
+ return result
+
+
+def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
+ """Serialize a "not implemented" object.
+
+ Args:
+ obj: object to serialize
+
+ Returns:
+ SerializedNotImplemented
+ """
+ _id: List[str] = []
+ try:
+ if hasattr(obj, "__name__"):
+ _id = [*obj.__module__.split("."), obj.__name__]
+ elif hasattr(obj, "__class__"):
+ _id = [
+ *obj.__class__.__module__.split("."),
+ obj.__class__.__name__,
+ ]
+ except Exception:
+ pass
+
+ result: SerializedNotImplemented = {
+ "lc": 1,
+ "type": "not_implemented",
+ "id": _id,
+ "repr": None,
+ }
+ try:
+ result["repr"] = repr(obj)
+ except Exception:
+ pass
+ return result
+
+
+class Document(Serializable):
+ """Class for storing a piece of text and associated metadata."""
+
+ page_content: str
+ """String text."""
+ metadata: dict = Field(default_factory=dict)
+ """Arbitrary metadata about the page content (e.g., source, relationships to other
+ documents, etc.).
+ """
+ type: Literal["Document"] = "Document"
+
+ def __init__(self, page_content: str, **kwargs: Any) -> None:
+ """Pass page_content in as positional or named arg."""
+ super().__init__(page_content=page_content, **kwargs)
+
+ @classmethod
+ def is_lc_serializable(cls) -> bool:
+ """Return whether this class is serializable."""
+ return True
+
+ @classmethod
+ def get_lc_namespace(cls) -> List[str]:
+ """Get the namespace of the langchain object."""
+ return ["langchain", "schema", "document"]
+
+
+class BaseDocumentTransformer(ABC):
+ """Abstract base class for document transformation systems.
+
+ A document transformation system takes a sequence of Documents and returns a
+ sequence of transformed Documents.
+
+ Example:
+ .. code-block:: python
+
+ class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel):
+ embeddings: Embeddings
+ similarity_fn: Callable = cosine_similarity
+ similarity_threshold: float = 0.95
+
+ class Config:
+ arbitrary_types_allowed = True
+
+ def transform_documents(
+ self, documents: Sequence[Document], **kwargs: Any
+ ) -> Sequence[Document]:
+ stateful_documents = get_stateful_documents(documents)
+ embedded_documents = _get_embeddings_from_stateful_docs(
+ self.embeddings, stateful_documents
+ )
+ included_idxs = _filter_similar_embeddings(
+ embedded_documents, self.similarity_fn, self.similarity_threshold
+ )
+ return [stateful_documents[i] for i in sorted(included_idxs)]
+
+ async def atransform_documents(
+ self, documents: Sequence[Document], **kwargs: Any
+ ) -> Sequence[Document]:
+ raise NotImplementedError
+
+ """ # noqa: E501
+
+ @abstractmethod
+ def transform_documents(
+ self, documents: Sequence[Document], **kwargs: Any
+ ) -> Sequence[Document]:
+ """Transform a list of documents.
+
+ Args:
+ documents: A sequence of Documents to be transformed.
+
+ Returns:
+ A list of transformed Documents.
+ """
+
+ async def atransform_documents(
+ self, documents: Sequence[Document], **kwargs: Any
+ ) -> Sequence[Document]:
+ """Asynchronously transform a list of documents.
+
+ Args:
+ documents: A sequence of Documents to be transformed.
+
+ Returns:
+ A list of transformed Documents.
+ """
+ raise NotImplementedError("This method is not implemented.")
+ # return await langchain_core.runnables.config.run_in_executor(
+ # None, self.transform_documents, documents, **kwargs
+ # )
+
+
+def _make_spacy_pipe_for_splitting(
+ pipe: str, *, max_length: int = 1_000_000
+) -> Any: # avoid importing spacy
+ try:
+ import spacy
+ except ImportError:
+ raise ImportError(
+ "Spacy is not installed, please install it with `pip install spacy`."
+ )
+ if pipe == "sentencizer":
+ from spacy.lang.en import English
+
+ sentencizer = English()
+ sentencizer.add_pipe("sentencizer")
+ else:
+ sentencizer = spacy.load(pipe, exclude=["ner", "tagger"])
+ sentencizer.max_length = max_length
+ return sentencizer
+
+
+def _split_text_with_regex(
+ text: str, separator: str, keep_separator: bool
+) -> List[str]:
+ # Now that we have the separator, split the text
+ if separator:
+ if keep_separator:
+ # The parentheses in the pattern keep the delimiters in the result.
+ _splits = re.split(f"({separator})", text)
+ splits = [
+ _splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)
+ ]
+ if len(_splits) % 2 == 0:
+ splits += _splits[-1:]
+ splits = [_splits[0]] + splits
+ else:
+ splits = re.split(separator, text)
+ else:
+ splits = list(text)
+ return [s for s in splits if s != ""]
+
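+# A hedged illustration (added commentary, not part of the original module)
+# of how the helper above treats the separator; the values shown are what
+# re.split produces for this input:
+#
+#   >>> _split_text_with_regex("a.b.c", r"\.", keep_separator=True)
+#   ['a', '.b', '.c']
+#   >>> _split_text_with_regex("a.b.c", r"\.", keep_separator=False)
+#   ['a', 'b', 'c']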
+
+class TextSplitter(BaseDocumentTransformer, ABC):
+ """Interface for splitting text into chunks."""
+
+ def __init__(
+ self,
+ chunk_size: int = 4000,
+ chunk_overlap: int = 200,
+ length_function: Callable[[str], int] = len,
+ keep_separator: bool = False,
+ add_start_index: bool = False,
+ strip_whitespace: bool = True,
+ ) -> None:
+ """Create a new TextSplitter.
+
+ Args:
+ chunk_size: Maximum size of chunks to return
+ chunk_overlap: Overlap in characters between chunks
+ length_function: Function that measures the length of given chunks
+ keep_separator: Whether to keep the separator in the chunks
+ add_start_index: If `True`, includes chunk's start index in metadata
+ strip_whitespace: If `True`, strips whitespace from the start and end of
+ every document
+ """
+ if chunk_overlap > chunk_size:
+ raise ValueError(
+ f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
+ f"({chunk_size}), should be smaller."
+ )
+ self._chunk_size = chunk_size
+ self._chunk_overlap = chunk_overlap
+ self._length_function = length_function
+ self._keep_separator = keep_separator
+ self._add_start_index = add_start_index
+ self._strip_whitespace = strip_whitespace
+
+ @abstractmethod
+ def split_text(self, text: str) -> List[str]:
+ """Split text into multiple components."""
+
+ def create_documents(
+ self, texts: List[str], metadatas: Optional[List[dict]] = None
+ ) -> List[Document]:
+ """Create documents from a list of texts."""
+ _metadatas = metadatas or [{}] * len(texts)
+ documents = []
+ for i, text in enumerate(texts):
+ index = 0
+ previous_chunk_len = 0
+ for chunk in self.split_text(text):
+ metadata = copy.deepcopy(_metadatas[i])
+ if self._add_start_index:
+ offset = index + previous_chunk_len - self._chunk_overlap
+ index = text.find(chunk, max(0, offset))
+ metadata["start_index"] = index
+ previous_chunk_len = len(chunk)
+ new_doc = Document(page_content=chunk, metadata=metadata)
+ documents.append(new_doc)
+ return documents
+
+ def split_documents(self, documents: Iterable[Document]) -> List[Document]:
+ """Split documents."""
+ texts, metadatas = [], []
+ for doc in documents:
+ texts.append(doc.page_content)
+ metadatas.append(doc.metadata)
+ return self.create_documents(texts, metadatas=metadatas)
+
+ def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:
+ text = separator.join(docs)
+ if self._strip_whitespace:
+ text = text.strip()
+ if text == "":
+ return None
+ else:
+ return text
+
+ def _merge_splits(
+ self, splits: Iterable[str], separator: str
+ ) -> List[str]:
+ # We now want to combine these smaller pieces into medium size
+ # chunks to send to the LLM.
+ separator_len = self._length_function(separator)
+
+ docs = []
+ current_doc: List[str] = []
+ total = 0
+ for d in splits:
+ _len = self._length_function(d)
+ if (
+ total + _len + (separator_len if len(current_doc) > 0 else 0)
+ > self._chunk_size
+ ):
+ if total > self._chunk_size:
+ logger.warning(
+ f"Created a chunk of size {total}, "
+ f"which is longer than the specified {self._chunk_size}"
+ )
+ if len(current_doc) > 0:
+ doc = self._join_docs(current_doc, separator)
+ if doc is not None:
+ docs.append(doc)
+ # Keep popping from the front of current_doc while:
+ # - the accumulated length still exceeds the chunk overlap, or
+ # - adding the next split would still exceed the chunk size
+ #   (and there is something left to drop)
+ while total > self._chunk_overlap or (
+ total
+ + _len
+ + (separator_len if len(current_doc) > 0 else 0)
+ > self._chunk_size
+ and total > 0
+ ):
+ total -= self._length_function(current_doc[0]) + (
+ separator_len if len(current_doc) > 1 else 0
+ )
+ current_doc = current_doc[1:]
+ current_doc.append(d)
+ total += _len + (separator_len if len(current_doc) > 1 else 0)
+ doc = self._join_docs(current_doc, separator)
+ if doc is not None:
+ docs.append(doc)
+ return docs
+
+ @classmethod
+ def from_huggingface_tokenizer(
+ cls, tokenizer: Any, **kwargs: Any
+ ) -> TextSplitter:
+ """Text splitter that uses HuggingFace tokenizer to count length."""
+ try:
+ from transformers import PreTrainedTokenizerBase
+
+ if not isinstance(tokenizer, PreTrainedTokenizerBase):
+ raise ValueError(
+ "Tokenizer received was not an instance of PreTrainedTokenizerBase"
+ )
+
+ def _huggingface_tokenizer_length(text: str) -> int:
+ return len(tokenizer.encode(text))
+
+ except ImportError:
+ raise ValueError(
+ "Could not import transformers python package. "
+ "Please install it with `pip install transformers`."
+ )
+ return cls(length_function=_huggingface_tokenizer_length, **kwargs)
+
+ @classmethod
+ def from_tiktoken_encoder(
+ cls: Type[TS],
+ encoding_name: str = "gpt2",
+ model: Optional[str] = None,
+ allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
+ disallowed_special: Union[Literal["all"], Collection[str]] = "all",
+ **kwargs: Any,
+ ) -> TS:
+ """Text splitter that uses tiktoken encoder to count length."""
+ try:
+ import tiktoken
+ except ImportError:
+ raise ImportError(
+ "Could not import tiktoken python package. "
+ "This is needed in order to calculate max_tokens_for_prompt. "
+ "Please install it with `pip install tiktoken`."
+ )
+
+ if model is not None:
+ enc = tiktoken.encoding_for_model(model)
+ else:
+ enc = tiktoken.get_encoding(encoding_name)
+
+ def _tiktoken_encoder(text: str) -> int:
+ return len(
+ enc.encode(
+ text,
+ allowed_special=allowed_special,
+ disallowed_special=disallowed_special,
+ )
+ )
+
+ if issubclass(cls, TokenTextSplitter):
+ extra_kwargs = {
+ "encoding_name": encoding_name,
+ "model": model,
+ "allowed_special": allowed_special,
+ "disallowed_special": disallowed_special,
+ }
+ kwargs = {**kwargs, **extra_kwargs}
+
+ return cls(length_function=_tiktoken_encoder, **kwargs)
+
+ def transform_documents(
+ self, documents: Sequence[Document], **kwargs: Any
+ ) -> Sequence[Document]:
+ """Transform sequence of documents by splitting them."""
+ return self.split_documents(list(documents))
+
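+# Hedged usage sketch (added commentary, not part of the original module):
+# any concrete subclass, e.g. CharacterTextSplitter defined below, can count
+# length in tokens via the classmethod above, assuming tiktoken is installed:
+#
+#   >>> splitter = CharacterTextSplitter.from_tiktoken_encoder(
+#   ...     encoding_name="cl100k_base", chunk_size=100, chunk_overlap=10
+#   ... )
+#   >>> splitter.split_text("alpha\n\nbeta")
+#   ['alpha\n\nbeta']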
+
+class CharacterTextSplitter(TextSplitter):
+ """Splitting text that looks at characters."""
+
+ def __init__(
+ self,
+ separator: str = "\n\n",
+ is_separator_regex: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ """Create a new TextSplitter."""
+ super().__init__(**kwargs)
+ self._separator = separator
+ self._is_separator_regex = is_separator_regex
+
+ def split_text(self, text: str) -> List[str]:
+ """Split incoming text and return chunks."""
+ # First we naively split the large input into a bunch of smaller ones.
+ separator = (
+ self._separator
+ if self._is_separator_regex
+ else re.escape(self._separator)
+ )
+ splits = _split_text_with_regex(text, separator, self._keep_separator)
+ _separator = "" if self._keep_separator else self._separator
+ return self._merge_splits(splits, _separator)
+
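+# Hedged usage sketch (added commentary, not part of the original module):
+# splitting on blank lines and recording start offsets in the metadata:
+#
+#   >>> splitter = CharacterTextSplitter(
+#   ...     separator="\n\n", chunk_size=20, chunk_overlap=0, add_start_index=True
+#   ... )
+#   >>> docs = splitter.create_documents(["first paragraph\n\nsecond paragraph"])
+#   >>> [d.metadata["start_index"] for d in docs]
+#   [0, 17]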
+
+class LineType(TypedDict):
+ """Line type as typed dict."""
+
+ metadata: Dict[str, str]
+ content: str
+
+
+class HeaderType(TypedDict):
+ """Header type as typed dict."""
+
+ level: int
+ name: str
+ data: str
+
+
+class MarkdownHeaderTextSplitter:
+ """Splitting markdown files based on specified headers."""
+
+ def __init__(
+ self,
+ headers_to_split_on: List[Tuple[str, str]],
+ return_each_line: bool = False,
+ strip_headers: bool = True,
+ ):
+ """Create a new MarkdownHeaderTextSplitter.
+
+ Args:
+ headers_to_split_on: Headers we want to track
+ return_each_line: Return each line w/ associated headers
+ strip_headers: Strip split headers from the content of the chunk
+ """
+ # Output line-by-line or aggregated into chunks w/ common headers
+ self.return_each_line = return_each_line
+ # Given the headers we want to split on,
+ # (e.g., "#, ##, etc") order by length
+ self.headers_to_split_on = sorted(
+ headers_to_split_on, key=lambda split: len(split[0]), reverse=True
+ )
+ # Strip split headers from the content of the chunk
+ self.strip_headers = strip_headers
+
+ def aggregate_lines_to_chunks(
+ self, lines: List[LineType]
+ ) -> List[Document]:
+ """Combine lines with common metadata into chunks
+ Args:
+ lines: Line of text / associated header metadata
+ """
+ aggregated_chunks: List[LineType] = []
+
+ for line in lines:
+ if (
+ aggregated_chunks
+ and aggregated_chunks[-1]["metadata"] == line["metadata"]
+ ):
+ # If the last line in the aggregated list
+ # has the same metadata as the current line,
+ # append the current content to the last line's content
+ aggregated_chunks[-1]["content"] += " \n" + line["content"]
+ elif (
+ aggregated_chunks
+ and aggregated_chunks[-1]["metadata"] != line["metadata"]
+ # there may be issues if other metadata is present
+ and len(aggregated_chunks[-1]["metadata"])
+ < len(line["metadata"])
+ and aggregated_chunks[-1]["content"].split("\n")[-1][0] == "#"
+ and not self.strip_headers
+ ):
+ # If the last line in the aggregated list
+ # has different metadata than the current line,
+ # and has shallower header level than the current line,
+ # and the last line is a header,
+ # and we are not stripping headers,
+ # append the current content to the last line's content
+ aggregated_chunks[-1]["content"] += " \n" + line["content"]
+ # and update the last line's metadata
+ aggregated_chunks[-1]["metadata"] = line["metadata"]
+ else:
+ # Otherwise, append the current line to the aggregated list
+ aggregated_chunks.append(line)
+
+ return [
+ Document(page_content=chunk["content"], metadata=chunk["metadata"])
+ for chunk in aggregated_chunks
+ ]
+
+ def split_text(self, text: str) -> List[Document]:
+ """Split markdown file
+ Args:
+ text: Markdown file"""
+
+ # Split the input text by newline character ("\n").
+ lines = text.split("\n")
+ # Final output
+ lines_with_metadata: List[LineType] = []
+ # Content and metadata of the chunk currently being processed
+ current_content: List[str] = []
+ current_metadata: Dict[str, str] = {}
+ # Keep track of the nested header structure
+ # header_stack: List[Dict[str, Union[int, str]]] = []
+ header_stack: List[HeaderType] = []
+ initial_metadata: Dict[str, str] = {}
+
+ in_code_block = False
+ opening_fence = ""
+
+ for line in lines:
+ stripped_line = line.strip()
+
+ if not in_code_block:
+ # Exclude inline code spans
+ if (
+ stripped_line.startswith("```")
+ and stripped_line.count("```") == 1
+ ):
+ in_code_block = True
+ opening_fence = "```"
+ elif stripped_line.startswith("~~~"):
+ in_code_block = True
+ opening_fence = "~~~"
+ else:
+ if stripped_line.startswith(opening_fence):
+ in_code_block = False
+ opening_fence = ""
+
+ if in_code_block:
+ current_content.append(stripped_line)
+ continue
+
+ # Check each line against each of the header types (e.g., #, ##)
+ for sep, name in self.headers_to_split_on:
+ # Check if line starts with a header that we intend to split on
+ if stripped_line.startswith(sep) and (
+ # Header with no text OR header is followed by space
+ # Both are valid conditions that sep is being used as a header
+ len(stripped_line) == len(sep)
+ or stripped_line[len(sep)] == " "
+ ):
+ # Ensure we are tracking the header as metadata
+ if name is not None:
+ # Get the current header level
+ current_header_level = sep.count("#")
+
+ # Pop out headers of lower or same level from the stack
+ while (
+ header_stack
+ and header_stack[-1]["level"]
+ >= current_header_level
+ ):
+ # We have encountered a new header
+ # at the same or higher level
+ popped_header = header_stack.pop()
+ # Clear the metadata for the
+ # popped header in initial_metadata
+ if popped_header["name"] in initial_metadata:
+ initial_metadata.pop(popped_header["name"])
+
+ # Push the current header to the stack
+ header: HeaderType = {
+ "level": current_header_level,
+ "name": name,
+ "data": stripped_line[len(sep) :].strip(),
+ }
+ header_stack.append(header)
+ # Update initial_metadata with the current header
+ initial_metadata[name] = header["data"]
+
+ # Add the previous line to the lines_with_metadata
+ # only if current_content is not empty
+ if current_content:
+ lines_with_metadata.append(
+ {
+ "content": "\n".join(current_content),
+ "metadata": current_metadata.copy(),
+ }
+ )
+ current_content.clear()
+
+ if not self.strip_headers:
+ current_content.append(stripped_line)
+
+ break
+ else:
+ if stripped_line:
+ current_content.append(stripped_line)
+ elif current_content:
+ lines_with_metadata.append(
+ {
+ "content": "\n".join(current_content),
+ "metadata": current_metadata.copy(),
+ }
+ )
+ current_content.clear()
+
+ current_metadata = initial_metadata.copy()
+
+ if current_content:
+ lines_with_metadata.append(
+ {
+ "content": "\n".join(current_content),
+ "metadata": current_metadata,
+ }
+ )
+
+ # lines_with_metadata has each line with associated header metadata
+ # aggregate these into chunks based on common metadata
+ if not self.return_each_line:
+ return self.aggregate_lines_to_chunks(lines_with_metadata)
+ else:
+ return [
+ Document(
+ page_content=chunk["content"], metadata=chunk["metadata"]
+ )
+ for chunk in lines_with_metadata
+ ]
+
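+# Hedged usage sketch (added commentary, not part of the original module):
+#
+#   >>> md_splitter = MarkdownHeaderTextSplitter(
+#   ...     headers_to_split_on=[("#", "Header 1"), ("##", "Header 2")]
+#   ... )
+#   >>> docs = md_splitter.split_text("# Intro\n\nHello\n\n## Details\n\nWorld")
+#   >>> [(d.page_content, d.metadata) for d in docs]
+#   [('Hello', {'Header 1': 'Intro'}), ('World', {'Header 1': 'Intro', 'Header 2': 'Details'})]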
+
+class ElementType(TypedDict):
+ """Element type as typed dict."""
+
+ url: str
+ xpath: str
+ content: str
+ metadata: Dict[str, str]
+
+
+class HTMLHeaderTextSplitter:
+ """
+ Splitting HTML files based on specified headers.
+ Requires lxml package.
+ """
+
+ def __init__(
+ self,
+ headers_to_split_on: List[Tuple[str, str]],
+ return_each_element: bool = False,
+ ):
+ """Create a new HTMLHeaderTextSplitter.
+
+ Args:
+ headers_to_split_on: list of tuples of headers we want to track mapped to
+ (arbitrary) keys for metadata. Allowed header values: h1, h2, h3, h4,
+ h5, h6, e.g. [("h1", "Header 1"), ("h2", "Header 2")].
+ return_each_element: Return each element w/ associated headers.
+ """
+ # Output element-by-element or aggregated into chunks w/ common headers
+ self.return_each_element = return_each_element
+ self.headers_to_split_on = sorted(headers_to_split_on)
+
+ def aggregate_elements_to_chunks(
+ self, elements: List[ElementType]
+ ) -> List[Document]:
+ """Combine elements with common metadata into chunks
+
+ Args:
+ elements: HTML element content with associated identifying info and metadata
+ """
+ aggregated_chunks: List[ElementType] = []
+
+ for element in elements:
+ if (
+ aggregated_chunks
+ and aggregated_chunks[-1]["metadata"] == element["metadata"]
+ ):
+ # If the last element in the aggregated list
+ # has the same metadata as the current element,
+ # append the current content to the last element's content
+ aggregated_chunks[-1]["content"] += " \n" + element["content"]
+ else:
+ # Otherwise, append the current element to the aggregated list
+ aggregated_chunks.append(element)
+
+ return [
+ Document(page_content=chunk["content"], metadata=chunk["metadata"])
+ for chunk in aggregated_chunks
+ ]
+
+ def split_text_from_url(self, url: str) -> List[Document]:
+ """Split HTML from web URL
+
+ Args:
+ url: web URL
+ """
+ r = requests.get(url)
+ return self.split_text_from_file(BytesIO(r.content))
+
+ def split_text(self, text: str) -> List[Document]:
+ """Split HTML text string
+
+ Args:
+ text: HTML text
+ """
+ return self.split_text_from_file(StringIO(text))
+
+ def split_text_from_file(self, file: Any) -> List[Document]:
+ """Split HTML file
+
+ Args:
+ file: HTML file
+ """
+ try:
+ from lxml import etree
+ except ImportError as e:
+ raise ImportError(
+ "Unable to import lxml, please install with `pip install lxml`."
+ ) from e
+ # use lxml library to parse html document and return xml ElementTree
+ # Explicitly encoding in utf-8 allows non-English
+ # html files to be processed without garbled characters
+ parser = etree.HTMLParser(encoding="utf-8")
+ tree = etree.parse(file, parser)
+
+ # document transformation for "structure-aware" chunking is handled with xsl.
+ # see comments in html_chunks_with_headers.xslt for more detailed information.
+ xslt_path = (
+ pathlib.Path(__file__).parent
+ / "document_transformers/xsl/html_chunks_with_headers.xslt"
+ )
+ xslt_tree = etree.parse(xslt_path)
+ transform = etree.XSLT(xslt_tree)
+ result = transform(tree)
+ result_dom = etree.fromstring(str(result))
+
+ # create filter and mapping for header metadata
+ header_filter = [header[0] for header in self.headers_to_split_on]
+ header_mapping = dict(self.headers_to_split_on)
+
+ # map xhtml namespace prefix
+ ns_map = {"h": "http://www.w3.org/1999/xhtml"}
+
+ # build list of elements from DOM
+ elements = []
+ for element in result_dom.findall("*//*", ns_map):
+ if element.findall("*[@class='headers']") or element.findall(
+ "*[@class='chunk']"
+ ):
+ elements.append(
+ ElementType(
+ url=file,
+ xpath="".join(
+ [
+ node.text
+ for node in element.findall(
+ "*[@class='xpath']", ns_map
+ )
+ ]
+ ),
+ content="".join(
+ [
+ node.text
+ for node in element.findall(
+ "*[@class='chunk']", ns_map
+ )
+ ]
+ ),
+ metadata={
+ # Add text of specified headers to metadata using header
+ # mapping.
+ header_mapping[node.tag]: node.text
+ for node in filter(
+ lambda x: x.tag in header_filter,
+ element.findall(
+ "*[@class='headers']/*", ns_map
+ ),
+ )
+ },
+ )
+ )
+
+ if not self.return_each_element:
+ return self.aggregate_elements_to_chunks(elements)
+ else:
+ return [
+ Document(
+ page_content=chunk["content"], metadata=chunk["metadata"]
+ )
+ for chunk in elements
+ ]
+
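+# Hedged sketch (added commentary, not part of the original module). Note
+# that split_text_from_file depends on the XSLT asset
+# "document_transformers/xsl/html_chunks_with_headers.xslt" shipped with
+# LangChain; if that file is not present next to this module the call will
+# fail, so treat this purely as an illustration of the intended API:
+#
+#   >>> html_splitter = HTMLHeaderTextSplitter(
+#   ...     headers_to_split_on=[("h1", "Header 1"), ("h2", "Header 2")]
+#   ... )
+#   >>> docs = html_splitter.split_text("<h1>Intro</h1><p>Hello</p>")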
+
+# In newer Python versions (3.10+) this could be:
+# @dataclass(frozen=True, kw_only=True, slots=True)
+@dataclass(frozen=True)
+class Tokenizer:
+ """Tokenizer data class."""
+
+ chunk_overlap: int
+ """Overlap in tokens between chunks"""
+ tokens_per_chunk: int
+ """Maximum number of tokens per chunk"""
+ decode: Callable[[List[int]], str]
+ """ Function to decode a list of token ids to a string"""
+ encode: Callable[[str], List[int]]
+ """ Function to encode a string to a list of token ids"""
+
+
+def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> List[str]:
+ """Split incoming text and return chunks using tokenizer."""
+ splits: List[str] = []
+ input_ids = tokenizer.encode(text)
+ start_idx = 0
+ cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
+ chunk_ids = input_ids[start_idx:cur_idx]
+ while start_idx < len(input_ids):
+ splits.append(tokenizer.decode(chunk_ids))
+ if cur_idx == len(input_ids):
+ break
+ start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
+ cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
+ chunk_ids = input_ids[start_idx:cur_idx]
+ return splits
+
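+# Worked sketch (added commentary, not part of the original module) of the
+# windowing above with a toy index-based "tokenizer": tokens_per_chunk=3 and
+# chunk_overlap=1 advance the window by 2 tokens at a time.
+#
+#   >>> words = "a b c d e".split()
+#   >>> toy = Tokenizer(
+#   ...     chunk_overlap=1,
+#   ...     tokens_per_chunk=3,
+#   ...     decode=lambda ids: " ".join(words[i] for i in ids),
+#   ...     encode=lambda text: list(range(len(text.split()))),
+#   ... )
+#   >>> split_text_on_tokens(text="a b c d e", tokenizer=toy)
+#   ['a b c', 'c d e']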
+
+class TokenTextSplitter(TextSplitter):
+ """Splitting text to tokens using model tokenizer."""
+
+ def __init__(
+ self,
+ encoding_name: str = "gpt2",
+ model: Optional[str] = None,
+ allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
+ disallowed_special: Union[Literal["all"], Collection[str]] = "all",
+ **kwargs: Any,
+ ) -> None:
+ """Create a new TextSplitter."""
+ super().__init__(**kwargs)
+ try:
+ import tiktoken
+ except ImportError:
+ raise ImportError(
+ "Could not import tiktoken python package. "
+ "This is needed in order to for TokenTextSplitter. "
+ "Please install it with `pip install tiktoken`."
+ )
+
+ if model is not None:
+ enc = tiktoken.encoding_for_model(model)
+ else:
+ enc = tiktoken.get_encoding(encoding_name)
+ self._tokenizer = enc
+ self._allowed_special = allowed_special
+ self._disallowed_special = disallowed_special
+
+ def split_text(self, text: str) -> List[str]:
+ def _encode(_text: str) -> List[int]:
+ return self._tokenizer.encode(
+ _text,
+ allowed_special=self._allowed_special,
+ disallowed_special=self._disallowed_special,
+ )
+
+ tokenizer = Tokenizer(
+ chunk_overlap=self._chunk_overlap,
+ tokens_per_chunk=self._chunk_size,
+ decode=self._tokenizer.decode,
+ encode=_encode,
+ )
+
+ return split_text_on_tokens(text=text, tokenizer=tokenizer)
+
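+# Hedged usage sketch (added commentary, not part of the original module),
+# assuming tiktoken is installed; chunk_size here is measured in tokens:
+#
+#   >>> token_splitter = TokenTextSplitter(
+#   ...     encoding_name="cl100k_base", chunk_size=256, chunk_overlap=16
+#   ... )
+#   >>> chunks = token_splitter.split_text("tokens " * 1000)
+#   >>> len(chunks) > 1
+#   True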
+
+class SentenceTransformersTokenTextSplitter(TextSplitter):
+ """Splitting text to tokens using sentence model tokenizer."""
+
+ def __init__(
+ self,
+ chunk_overlap: int = 50,
+ model: str = "sentence-transformers/all-mpnet-base-v2",
+ tokens_per_chunk: Optional[int] = None,
+ **kwargs: Any,
+ ) -> None:
+ """Create a new TextSplitter."""
+ super().__init__(**kwargs, chunk_overlap=chunk_overlap)
+
+ try:
+ from sentence_transformers import SentenceTransformer
+ except ImportError:
+ raise ImportError(
+ "Could not import sentence_transformer python package. "
+ "This is needed in order to for SentenceTransformersTokenTextSplitter. "
+ "Please install it with `pip install sentence-transformers`."
+ )
+
+ self.model = model
+ self._model = SentenceTransformer(self.model, trust_remote_code=True)
+ self.tokenizer = self._model.tokenizer
+ self._initialize_chunk_configuration(tokens_per_chunk=tokens_per_chunk)
+
+ def _initialize_chunk_configuration(
+ self, *, tokens_per_chunk: Optional[int]
+ ) -> None:
+ self.maximum_tokens_per_chunk = cast(int, self._model.max_seq_length)
+
+ if tokens_per_chunk is None:
+ self.tokens_per_chunk = self.maximum_tokens_per_chunk
+ else:
+ self.tokens_per_chunk = tokens_per_chunk
+
+ if self.tokens_per_chunk > self.maximum_tokens_per_chunk:
+ raise ValueError(
+ f"The token limit of the models '{self.model}'"
+ f" is: {self.maximum_tokens_per_chunk}."
+ f" Argument tokens_per_chunk={self.tokens_per_chunk}"
+ f" > maximum token limit."
+ )
+
+ def split_text(self, text: str) -> List[str]:
+ def encode_strip_start_and_stop_token_ids(text: str) -> List[int]:
+ return self._encode(text)[1:-1]
+
+ tokenizer = Tokenizer(
+ chunk_overlap=self._chunk_overlap,
+ tokens_per_chunk=self.tokens_per_chunk,
+ decode=self.tokenizer.decode,
+ encode=encode_strip_start_and_stop_token_ids,
+ )
+
+ return split_text_on_tokens(text=text, tokenizer=tokenizer)
+
+ def count_tokens(self, *, text: str) -> int:
+ return len(self._encode(text))
+
+ _max_length_equal_32_bit_integer: int = 2**32
+
+ def _encode(self, text: str) -> List[int]:
+ token_ids_with_start_and_end_token_ids = self.tokenizer.encode(
+ text,
+ max_length=self._max_length_equal_32_bit_integer,
+ truncation="do_not_truncate",
+ )
+ return token_ids_with_start_and_end_token_ids
+
+
+class Language(str, Enum):
+ """Enum of the programming languages."""
+
+ CPP = "cpp"
+ GO = "go"
+ JAVA = "java"
+ KOTLIN = "kotlin"
+ JS = "js"
+ TS = "ts"
+ PHP = "php"
+ PROTO = "proto"
+ PYTHON = "python"
+ RST = "rst"
+ RUBY = "ruby"
+ RUST = "rust"
+ SCALA = "scala"
+ SWIFT = "swift"
+ MARKDOWN = "markdown"
+ LATEX = "latex"
+ HTML = "html"
+ SOL = "sol"
+ CSHARP = "csharp"
+ COBOL = "cobol"
+ C = "c"
+ LUA = "lua"
+ PERL = "perl"
+
+
+class RecursiveCharacterTextSplitter(TextSplitter):
+ """Splitting text by recursively look at characters.
+
+ Recursively tries to split by different characters to find one
+ that works.
+ """
+
+ def __init__(
+ self,
+ separators: Optional[List[str]] = None,
+ keep_separator: bool = True,
+ is_separator_regex: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ """Create a new TextSplitter."""
+ super().__init__(keep_separator=keep_separator, **kwargs)
+ self._separators = separators or ["\n\n", "\n", " ", ""]
+ self._is_separator_regex = is_separator_regex
+
+ def _split_text(self, text: str, separators: List[str]) -> List[str]:
+ """Split incoming text and return chunks."""
+ final_chunks = []
+ # Get appropriate separator to use
+ separator = separators[-1]
+ new_separators = []
+ for i, _s in enumerate(separators):
+ _separator = _s if self._is_separator_regex else re.escape(_s)
+ if _s == "":
+ separator = _s
+ break
+ if re.search(_separator, text):
+ separator = _s
+ new_separators = separators[i + 1 :]
+ break
+
+ _separator = (
+ separator if self._is_separator_regex else re.escape(separator)
+ )
+ splits = _split_text_with_regex(text, _separator, self._keep_separator)
+
+ # Now go merging things, recursively splitting longer texts.
+ _good_splits = []
+ _separator = "" if self._keep_separator else separator
+ for s in splits:
+ if self._length_function(s) < self._chunk_size:
+ _good_splits.append(s)
+ else:
+ if _good_splits:
+ merged_text = self._merge_splits(_good_splits, _separator)
+ final_chunks.extend(merged_text)
+ _good_splits = []
+ if not new_separators:
+ final_chunks.append(s)
+ else:
+ other_info = self._split_text(s, new_separators)
+ final_chunks.extend(other_info)
+ if _good_splits:
+ merged_text = self._merge_splits(_good_splits, _separator)
+ final_chunks.extend(merged_text)
+ return final_chunks
+
+ def split_text(self, text: str) -> List[str]:
+ return self._split_text(text, self._separators)
+
+ @classmethod
+ def from_language(
+ cls, language: Language, **kwargs: Any
+ ) -> RecursiveCharacterTextSplitter:
+ separators = cls.get_separators_for_language(language)
+ return cls(separators=separators, is_separator_regex=True, **kwargs)
+
+ @staticmethod
+ def get_separators_for_language(language: Language) -> List[str]:
+ if language == Language.CPP:
+ return [
+ # Split along class definitions
+ "\nclass ",
+ # Split along function definitions
+ "\nvoid ",
+ "\nint ",
+ "\nfloat ",
+ "\ndouble ",
+ # Split along control flow statements
+ "\nif ",
+ "\nfor ",
+ "\nwhile ",
+ "\nswitch ",
+ "\ncase ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.GO:
+ return [
+ # Split along function definitions
+ "\nfunc ",
+ "\nvar ",
+ "\nconst ",
+ "\ntype ",
+ # Split along control flow statements
+ "\nif ",
+ "\nfor ",
+ "\nswitch ",
+ "\ncase ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.JAVA:
+ return [
+ # Split along class definitions
+ "\nclass ",
+ # Split along method definitions
+ "\npublic ",
+ "\nprotected ",
+ "\nprivate ",
+ "\nstatic ",
+ # Split along control flow statements
+ "\nif ",
+ "\nfor ",
+ "\nwhile ",
+ "\nswitch ",
+ "\ncase ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.KOTLIN:
+ return [
+ # Split along class definitions
+ "\nclass ",
+ # Split along method definitions
+ "\npublic ",
+ "\nprotected ",
+ "\nprivate ",
+ "\ninternal ",
+ "\ncompanion ",
+ "\nfun ",
+ "\nval ",
+ "\nvar ",
+ # Split along control flow statements
+ "\nif ",
+ "\nfor ",
+ "\nwhile ",
+ "\nwhen ",
+ "\ncase ",
+ "\nelse ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.JS:
+ return [
+ # Split along function definitions
+ "\nfunction ",
+ "\nconst ",
+ "\nlet ",
+ "\nvar ",
+ "\nclass ",
+ # Split along control flow statements
+ "\nif ",
+ "\nfor ",
+ "\nwhile ",
+ "\nswitch ",
+ "\ncase ",
+ "\ndefault ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.TS:
+ return [
+ "\nenum ",
+ "\ninterface ",
+ "\nnamespace ",
+ "\ntype ",
+ # Split along class definitions
+ "\nclass ",
+ # Split along function definitions
+ "\nfunction ",
+ "\nconst ",
+ "\nlet ",
+ "\nvar ",
+ # Split along control flow statements
+ "\nif ",
+ "\nfor ",
+ "\nwhile ",
+ "\nswitch ",
+ "\ncase ",
+ "\ndefault ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.PHP:
+ return [
+ # Split along function definitions
+ "\nfunction ",
+ # Split along class definitions
+ "\nclass ",
+ # Split along control flow statements
+ "\nif ",
+ "\nforeach ",
+ "\nwhile ",
+ "\ndo ",
+ "\nswitch ",
+ "\ncase ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.PROTO:
+ return [
+ # Split along message definitions
+ "\nmessage ",
+ # Split along service definitions
+ "\nservice ",
+ # Split along enum definitions
+ "\nenum ",
+ # Split along option definitions
+ "\noption ",
+ # Split along import statements
+ "\nimport ",
+ # Split along syntax declarations
+ "\nsyntax ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.PYTHON:
+ return [
+ # First, try to split along class definitions
+ "\nclass ",
+ "\ndef ",
+ "\n\tdef ",
+ # Now split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.RST:
+ return [
+ # Split along section titles
+ "\n=+\n",
+ "\n-+\n",
+ "\n\\*+\n",
+ # Split along directive markers
+ "\n\n.. *\n\n",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.RUBY:
+ return [
+ # Split along method definitions
+ "\ndef ",
+ "\nclass ",
+ # Split along control flow statements
+ "\nif ",
+ "\nunless ",
+ "\nwhile ",
+ "\nfor ",
+ "\ndo ",
+ "\nbegin ",
+ "\nrescue ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.RUST:
+ return [
+ # Split along function definitions
+ "\nfn ",
+ "\nconst ",
+ "\nlet ",
+ # Split along control flow statements
+ "\nif ",
+ "\nwhile ",
+ "\nfor ",
+ "\nloop ",
+ "\nmatch ",
+ "\nconst ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.SCALA:
+ return [
+ # Split along class definitions
+ "\nclass ",
+ "\nobject ",
+ # Split along method definitions
+ "\ndef ",
+ "\nval ",
+ "\nvar ",
+ # Split along control flow statements
+ "\nif ",
+ "\nfor ",
+ "\nwhile ",
+ "\nmatch ",
+ "\ncase ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.SWIFT:
+ return [
+ # Split along function definitions
+ "\nfunc ",
+ # Split along class definitions
+ "\nclass ",
+ "\nstruct ",
+ "\nenum ",
+ # Split along control flow statements
+ "\nif ",
+ "\nfor ",
+ "\nwhile ",
+ "\ndo ",
+ "\nswitch ",
+ "\ncase ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.MARKDOWN:
+ return [
+ # First, try to split along Markdown headings (levels 1 through 6)
+ "\n#{1,6} ",
+ # Note the alternative syntax for headings (below) is not handled here
+ # Heading level 2
+ # ---------------
+ # End of code block
+ "```\n",
+ # Horizontal lines
+ "\n\\*\\*\\*+\n",
+ "\n---+\n",
+ "\n___+\n",
+ # Note that the horizontal-rule patterns above only match an unbroken
+ # run of three or more *, -, or _; spaced variants such as "* * *" are
+ # not handled
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.LATEX:
+ return [
+ # First, try to split along Latex sections
+ "\n\\\\chapter{",
+ "\n\\\\section{",
+ "\n\\\\subsection{",
+ "\n\\\\subsubsection{",
+ # Now split by environments
+ "\n\\\\begin{enumerate}",
+ "\n\\\\begin{itemize}",
+ "\n\\\\begin{description}",
+ "\n\\\\begin{list}",
+ "\n\\\\begin{quote}",
+ "\n\\\\begin{quotation}",
+ "\n\\\\begin{verse}",
+ "\n\\\\begin{verbatim}",
+ # Now split by math environments
+ "\n\\\begin{align}",
+ "$$",
+ "$",
+ # Now split by the normal type of lines
+ " ",
+ "",
+ ]
+ elif language == Language.HTML:
+ return [
+ # First, try to split along HTML tags
+ "<body",
+ "<div",
+ "<p",
+ "<br",
+ "<li",
+ "<h1",
+ "<h2",
+ "<h3",
+ "<h4",
+ "<h5",
+ "<h6",
+ "<span",
+ "<table",
+ "<tr",
+ "<td",
+ "<th",
+ "<ul",
+ "<ol",
+ "<header",
+ "<footer",
+ "<nav",
+ # Head
+ "<head",
+ "<style",
+ "<script",
+ "<meta",
+ "<title",
+ "",
+ ]
+ elif language == Language.CSHARP:
+ return [
+ "\ninterface ",
+ "\nenum ",
+ "\nimplements ",
+ "\ndelegate ",
+ "\nevent ",
+ # Split along class definitions
+ "\nclass ",
+ "\nabstract ",
+ # Split along method definitions
+ "\npublic ",
+ "\nprotected ",
+ "\nprivate ",
+ "\nstatic ",
+ "\nreturn ",
+ # Split along control flow statements
+ "\nif ",
+ "\ncontinue ",
+ "\nfor ",
+ "\nforeach ",
+ "\nwhile ",
+ "\nswitch ",
+ "\nbreak ",
+ "\ncase ",
+ "\nelse ",
+ # Split by exceptions
+ "\ntry ",
+ "\nthrow ",
+ "\nfinally ",
+ "\ncatch ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.SOL:
+ return [
+ # Split along compiler information definitions
+ "\npragma ",
+ "\nusing ",
+ # Split along contract definitions
+ "\ncontract ",
+ "\ninterface ",
+ "\nlibrary ",
+ # Split along method definitions
+ "\nconstructor ",
+ "\ntype ",
+ "\nfunction ",
+ "\nevent ",
+ "\nmodifier ",
+ "\nerror ",
+ "\nstruct ",
+ "\nenum ",
+ # Split along control flow statements
+ "\nif ",
+ "\nfor ",
+ "\nwhile ",
+ "\ndo while ",
+ "\nassembly ",
+ # Split by the normal type of lines
+ "\n\n",
+ "\n",
+ " ",
+ "",
+ ]
+ elif language == Language.COBOL:
+ return [
+ # Split along divisions
+ "\nIDENTIFICATION DIVISION.",
+ "\nENVIRONMENT DIVISION.",
+ "\nDATA DIVISION.",
+ "\nPROCEDURE DIVISION.",
+ # Split along sections within DATA DIVISION
+ "\nWORKING-STORAGE SECTION.",
+ "\nLINKAGE SECTION.",
+ "\nFILE SECTION.",
+ # Split along sections within PROCEDURE DIVISION
+ "\nINPUT-OUTPUT SECTION.",
+ # Split along paragraphs and common statements
+ "\nOPEN ",
+ "\nCLOSE ",
+ "\nREAD ",
+ "\nWRITE ",
+ "\nIF ",
+ "\nELSE ",
+ "\nMOVE ",
+ "\nPERFORM ",
+ "\nUNTIL ",
+ "\nVARYING ",
+ "\nACCEPT ",
+ "\nDISPLAY ",
+ "\nSTOP RUN.",
+ # Split by the normal type of lines
+ "\n",
+ " ",
+ "",
+ ]
+
+ else:
+ raise ValueError(
+ f"Language {language} is not supported! "
+ f"Please choose from {list(Language)}"
+ )
+
+
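+# Hedged usage sketch (added commentary, not part of the original module):
+#
+#   >>> rec_splitter = RecursiveCharacterTextSplitter(chunk_size=60, chunk_overlap=0)
+#   >>> rec_splitter.split_text(
+#   ...     "Paragraph one.\n\nParagraph two is a little bit longer than one."
+#   ... )
+#   ['Paragraph one.', 'Paragraph two is a little bit longer than one.']
+#
+#   >>> py_splitter = RecursiveCharacterTextSplitter.from_language(
+#   ...     Language.PYTHON, chunk_size=200, chunk_overlap=0
+#   ... )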
+class NLTKTextSplitter(TextSplitter):
+ """Splitting text using NLTK package."""
+
+ def __init__(
+ self, separator: str = "\n\n", language: str = "english", **kwargs: Any
+ ) -> None:
+ """Initialize the NLTK splitter."""
+ super().__init__(**kwargs)
+ try:
+ from nltk.tokenize import sent_tokenize
+
+ self._tokenizer = sent_tokenize
+ except ImportError:
+ raise ImportError(
+ "NLTK is not installed, please install it with `pip install nltk`."
+ )
+ self._separator = separator
+ self._language = language
+
+ def split_text(self, text: str) -> List[str]:
+ """Split incoming text and return chunks."""
+ # First we naively split the large input into a bunch of smaller ones.
+ splits = self._tokenizer(text, language=self._language)
+ return self._merge_splits(splits, self._separator)
+
+
+class SpacyTextSplitter(TextSplitter):
+ """Splitting text using Spacy package.
+
+
+ By default, Spacy's `en_core_web_sm` pipeline is used, and its default
+ max_length is 1,000,000 (the maximum number of characters the model
+ accepts, which can be increased for large files). For faster, but
+ potentially less accurate, splitting you can use `pipe='sentencizer'`.
+ """
+
+ def __init__(
+ self,
+ separator: str = "\n\n",
+ pipe: str = "en_core_web_sm",
+ max_length: int = 1_000_000,
+ **kwargs: Any,
+ ) -> None:
+ """Initialize the spacy text splitter."""
+ super().__init__(**kwargs)
+ self._tokenizer = _make_spacy_pipe_for_splitting(
+ pipe, max_length=max_length
+ )
+ self._separator = separator
+
+ def split_text(self, text: str) -> List[str]:
+ """Split incoming text and return chunks."""
+ splits = (s.text for s in self._tokenizer(text).sents)
+ return self._merge_splits(splits, self._separator)
+
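+# Hedged usage sketch (added commentary, not part of the original module) for
+# the sentence-based splitters above; each needs its optional dependency
+# (nltk with the "punkt" data, or spacy) installed beforehand:
+#
+#   >>> nltk_splitter = NLTKTextSplitter(chunk_size=1000, chunk_overlap=0)
+#   >>> spacy_splitter = SpacyTextSplitter(pipe="sentencizer", chunk_size=1000, chunk_overlap=0)
+#   >>> spacy_splitter.split_text("First sentence. Second sentence.")
+#   ['First sentence.\n\nSecond sentence.']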
+
+class KonlpyTextSplitter(TextSplitter):
+ """Splitting text using Konlpy package.
+
+ It is good for splitting Korean text.
+ """
+
+ def __init__(
+ self,
+ separator: str = "\n\n",
+ **kwargs: Any,
+ ) -> None:
+ """Initialize the Konlpy text splitter."""
+ super().__init__(**kwargs)
+ self._separator = separator
+ try:
+ from konlpy.tag import Kkma
+ except ImportError:
+ raise ImportError(
+ """
+ Konlpy is not installed, please install it with
+ `pip install konlpy`
+ """
+ )
+ self.kkma = Kkma()
+
+ def split_text(self, text: str) -> List[str]:
+ """Split incoming text and return chunks."""
+ splits = self.kkma.sentences(text)
+ return self._merge_splits(splits, self._separator)
+
+
+# For backwards compatibility
+class PythonCodeTextSplitter(RecursiveCharacterTextSplitter):
+ """Attempts to split the text along Python syntax."""
+
+ def __init__(self, **kwargs: Any) -> None:
+ """Initialize a PythonCodeTextSplitter."""
+ separators = self.get_separators_for_language(Language.PYTHON)
+ super().__init__(separators=separators, **kwargs)
+
+
+class MarkdownTextSplitter(RecursiveCharacterTextSplitter):
+ """Attempts to split the text along Markdown-formatted headings."""
+
+ def __init__(self, **kwargs: Any) -> None:
+ """Initialize a MarkdownTextSplitter."""
+ separators = self.get_separators_for_language(Language.MARKDOWN)
+ super().__init__(separators=separators, **kwargs)
+
+
+class LatexTextSplitter(RecursiveCharacterTextSplitter):
+ """Attempts to split the text along Latex-formatted layout elements."""
+
+ def __init__(self, **kwargs: Any) -> None:
+ """Initialize a LatexTextSplitter."""
+ separators = self.get_separators_for_language(Language.LATEX)
+ super().__init__(separators=separators, **kwargs)
+
+
+class RecursiveJsonSplitter:
+ def __init__(
+ self, max_chunk_size: int = 2000, min_chunk_size: Optional[int] = None
+ ):
+ super().__init__()
+ self.max_chunk_size = max_chunk_size
+ self.min_chunk_size = (
+ min_chunk_size
+ if min_chunk_size is not None
+ else max(max_chunk_size - 200, 50)
+ )
+
+ @staticmethod
+ def _json_size(data: Dict) -> int:
+ """Calculate the size of the serialized JSON object."""
+ return len(json.dumps(data))
+
+ @staticmethod
+ def _set_nested_dict(d: Dict, path: List[str], value: Any) -> None:
+ """Set a value in a nested dictionary based on the given path."""
+ for key in path[:-1]:
+ d = d.setdefault(key, {})
+ d[path[-1]] = value
+
+ def _list_to_dict_preprocessing(self, data: Any) -> Any:
+ if isinstance(data, dict):
+ # Process each key-value pair in the dictionary
+ return {
+ k: self._list_to_dict_preprocessing(v) for k, v in data.items()
+ }
+ elif isinstance(data, list):
+ # Convert the list to a dictionary with index-based keys
+ return {
+ str(i): self._list_to_dict_preprocessing(item)
+ for i, item in enumerate(data)
+ }
+ else:
+ # Base case: the item is neither a dict nor a list, so return it unchanged
+ return data
+
+ def _json_split(
+ self,
+ data: Dict[str, Any],
+ current_path: Optional[List[str]] = None,
+ chunks: Optional[List[Dict]] = None,
+ ) -> List[Dict]:
+ """
+ Split json into maximum size dictionaries while preserving structure.
+ """
+ # Avoid mutable default arguments, which would otherwise leak chunk
+ # state across calls.
+ current_path = current_path if current_path is not None else []
+ chunks = chunks if chunks is not None else [{}]
+ if isinstance(data, dict):
+ for key, value in data.items():
+ new_path = current_path + [key]
+ chunk_size = self._json_size(chunks[-1])
+ size = self._json_size({key: value})
+ remaining = self.max_chunk_size - chunk_size
+
+ if size < remaining:
+ # Add item to current chunk
+ self._set_nested_dict(chunks[-1], new_path, value)
+ else:
+ if chunk_size >= self.min_chunk_size:
+ # Chunk is big enough, start a new chunk
+ chunks.append({})
+
+ # Iterate
+ self._json_split(value, new_path, chunks)
+ else:
+ # handle single item
+ self._set_nested_dict(chunks[-1], current_path, data)
+ return chunks
+
+ def split_json(
+ self,
+ json_data: Dict[str, Any],
+ convert_lists: bool = False,
+ ) -> List[Dict]:
+ """Splits JSON into a list of JSON chunks"""
+
+ if convert_lists:
+ chunks = self._json_split(
+ self._list_to_dict_preprocessing(json_data)
+ )
+ else:
+ chunks = self._json_split(json_data)
+
+ # Remove the last chunk if it's empty
+ if not chunks[-1]:
+ chunks.pop()
+ return chunks
+
+ def split_text(
+ self, json_data: Dict[str, Any], convert_lists: bool = False
+ ) -> List[str]:
+ """Splits JSON into a list of JSON formatted strings"""
+
+ chunks = self.split_json(
+ json_data=json_data, convert_lists=convert_lists
+ )
+
+ # Convert to string
+ return [json.dumps(chunk) for chunk in chunks]
+
+ def create_documents(
+ self,
+ texts: List[Dict],
+ convert_lists: bool = False,
+ metadatas: Optional[List[dict]] = None,
+ ) -> List[Document]:
+ """Create documents from a list of json objects (Dict)."""
+ _metadatas = metadatas or [{}] * len(texts)
+ documents = []
+ for i, text in enumerate(texts):
+ for chunk in self.split_text(
+ json_data=text, convert_lists=convert_lists
+ ):
+ metadata = copy.deepcopy(_metadatas[i])
+ new_doc = Document(page_content=chunk, metadata=metadata)
+ documents.append(new_doc)
+ return documents
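+# Hedged usage sketch (added commentary, not part of the original module):
+#
+#   >>> json_splitter = RecursiveJsonSplitter(max_chunk_size=40, min_chunk_size=10)
+#   >>> data = {"a": 1, "b": 2, "c": "a fairly long string value"}
+#   >>> json_splitter.split_json(data)
+#   [{'a': 1, 'b': 2}, {'c': 'a fairly long string value'}]
+#   >>> json_splitter.split_text(data)
+#   ['{"a": 1, "b": 2}', '{"c": "a fairly long string value"}']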