活在梦里

Langchain笔记:Textsplitter

RecursiveCharacterTextSplitter

一直好奇所谓的递归分割是怎么个递归法,打开源码一探究竟。

源码位置:langchain_text_splitters.character.RecursiveCharacterTextSplitter

类定义和初始化

class RecursiveCharacterTextSplitter(TextSplitter):
    """Splitting text by recursively look at characters.

    Recursively tries to split by different characters to find one
    that works.
    """

    def __init__(
        self,
        separators: Optional[List[str]] = None,
        keep_separator: bool = True,
        is_separator_regex: bool = False,
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(keep_separator=keep_separator, **kwargs)
        self._separators = separators or ["\n\n", "\n", " ", ""]
        self._is_separator_regex = is_separator_regex

__init__ 方法用于初始化对象,接受以下参数:

  • separators: 一个可选的字符串列表,用于指定分割文本的分隔符。默认值为 ["\n\n", "\n", " ", ""]
  • keep_separator: 一个布尔值,表示是否保留分隔符。默认值为 True
  • is_separator_regex: 一个布尔值,表示分隔符是否为正则表达式。默认值为 False
  • **kwargs: 其他关键字参数,传递给父类的初始化方法。

文本分割方法

    def _split_text(self, text: str, separators: List[str]) -> List[str]:
        """Split incoming text and return chunks."""
        final_chunks = []
		# 从分隔符列表中遍历出第一个可以(即文本段中要能搜到这个分隔符)用于分割文本段的分隔符。
		# 记录剩余的分隔符为 new_separators。
		# 即:先尝试用"\n\n"分割,再"\n",再" ",最后""兜底就是不分了。
        # Get appropriate separator to use
        separator = separators[-1]
        new_separators = []
        for i, _s in enumerate(separators):
            _separator = _s if self._is_separator_regex else re.escape(_s)
            if _s == "":
                separator = _s
                break
            if re.search(_separator, text):
                separator = _s
                new_separators = separators[i + 1 :]
                break

		# 将正则分隔符和非正则分隔符统一为正则表达式的形式,然后用正则的方法对文档进行分割。
        _separator = separator if self._is_separator_regex else re.escape(separator)
        splits = _split_text_with_regex(text, _separator, self._keep_separator)

		# 初始化 _good_splits 用于存储长度小于 _chunk_size 的文本块。
        # Now go merging things, recursively splitting longer texts.
        _good_splits = []
        _separator = "" if self._keep_separator else separator
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
            else:
				# 如果文本块长度大于或等于 _chunk_size,则合并 _good_splits 中的文本块,并清空 _good_splits。
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, _separator)
                    final_chunks.extend(merged_text)
                    _good_splits = []
				# 如果没有剩余的分隔符 (new_separators 为空),则直接将当前文本块添加到 final_chunks。
                if not new_separators:
                    final_chunks.append(s)
				# 如果还有剩余的分隔符,递归调用 _split_text 方法继续分割当前文本块。
                else:
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(_good_splits, _separator)
            final_chunks.extend(merged_text)
        return final_chunks

所以,所谓递归分割的方法就是:按照分隔符的优先级进行分割。当使用高优先级分割出来的某个文档块超过指定长度,会对这个超长的文本块递归调用低优先级的分隔符继续进行分割。

other_info = self._split_text(s, new_separators)

使用正则表达式根据指定的分隔符分割文本

_split_text_with_regex 函数通过正则表达式根据指定的分隔符分割文本,并根据 keep_separator 参数决定是否保留分隔符。该函数在处理过程中会合并文本块和分隔符(如果需要保留分隔符),并过滤掉空字符串,最终返回一个非空的分割结果列表。

def _split_text_with_regex(
    text: str, separator: str, keep_separator: bool
) -> List[str]:
    # Now that we have the separator, split the text
    if separator:
        if keep_separator:
            # The parentheses in the pattern keep the delimiters in the result.
            _splits = re.split(f"({separator})", text)
			# 通过列表推导式 [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)] 将文本块和分隔符合并。
            splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
			# 如果 _splits 的长度为偶数,说明最后一个元素是单独的文本块,需要添加到 splits 中。
            if len(_splits) % 2 == 0:
                splits += _splits[-1:]
			# 将第一个元素(原始文本的第一个部分)添加到 splits 的开头。
            splits = [_splits[0]] + splits
        else:
            splits = re.split(separator, text)
    else:
        splits = list(text)
	# 过滤掉分割结果中的空字符串,返回最终的分割结果列表。
    return [s for s in splits if s != ""]

分割后的文本块合并成中等大小的块

分隔符可能将文本切分得非常稀碎,_merge_splits会将小的块合并到接近预设的_chunk_size长度。

就是一个简单的指针方法,current_doc用于存储当前指针攒的文档数量,total用于计算已经攒的文档的长度。长度够了就合并。

    def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
        # We now want to combine these smaller pieces into medium size
        # chunks to send to the LLM.
        separator_len = self._length_function(separator)

        docs = []
        current_doc: List[str] = []
        total = 0
        for d in splits:
            _len = self._length_function(d)
            if (
                total + _len + (separator_len if len(current_doc) > 0 else 0)
                > self._chunk_size
            ):
                if total > self._chunk_size:
                    logger.warning(
                        f"Created a chunk of size {total}, "
                        f"which is longer than the specified {self._chunk_size}"
                    )
                if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
					# 从current_doc开始删除文档块,保留小于_chunk_overlap的长度即可
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    while total > self._chunk_overlap or (
                        total + _len + (separator_len if len(current_doc) > 0 else 0)
                        > self._chunk_size
                        and total > 0
                    ):
                        total -= self._length_function(current_doc[0]) + (
                            separator_len if len(current_doc) > 1 else 0
                        )
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += _len + (separator_len if len(current_doc) > 1 else 0)
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs