Langchain笔记:Textsplitter
RecursiveCharacterTextSplitter
一直好奇所谓的递归分割是怎么个递归法,打开源码一探究竟。
源码位置:langchain_text_splitters.character.RecursiveCharacterTextSplitter
类定义和初始化
class RecursiveCharacterTextSplitter(TextSplitter):
"""Splitting text by recursively look at characters.
Recursively tries to split by different characters to find one
that works.
"""
def __init__(
self,
separators: Optional[List[str]] = None,
keep_separator: bool = True,
is_separator_regex: bool = False,
**kwargs: Any,
) -> None:
"""Create a new TextSplitter."""
super().__init__(keep_separator=keep_separator, **kwargs)
self._separators = separators or ["\n\n", "\n", " ", ""]
self._is_separator_regex = is_separator_regex
__init__
方法用于初始化对象,接受以下参数:
separators
: 一个可选的字符串列表,用于指定分割文本的分隔符。默认值为["\n\n", "\n", " ", ""]
。keep_separator
: 一个布尔值,表示是否保留分隔符。默认值为True
。is_separator_regex
: 一个布尔值,表示分隔符是否为正则表达式。默认值为False
。**kwargs
: 其他关键字参数,传递给父类的初始化方法。
文本分割方法
def _split_text(self, text: str, separators: List[str]) -> List[str]:
"""Split incoming text and return chunks."""
final_chunks = []
# 从分隔符列表中遍历出第一个可以(即文本段中要能搜到这个分隔符)用于分割文本段的分隔符。
# 记录剩余的分隔符为 new_separators。
# 即:先尝试用"\n\n"分割,再"\n",再" ",最后""兜底就是不分了。
# Get appropriate separator to use
separator = separators[-1]
new_separators = []
for i, _s in enumerate(separators):
_separator = _s if self._is_separator_regex else re.escape(_s)
if _s == "":
separator = _s
break
if re.search(_separator, text):
separator = _s
new_separators = separators[i + 1 :]
break
# 将正则分隔符和非正则分隔符统一为正则表达式的形式,然后用正则的方法对文档进行分割。
_separator = separator if self._is_separator_regex else re.escape(separator)
splits = _split_text_with_regex(text, _separator, self._keep_separator)
# 初始化 _good_splits 用于存储长度小于 _chunk_size 的文本块。
# Now go merging things, recursively splitting longer texts.
_good_splits = []
_separator = "" if self._keep_separator else separator
for s in splits:
if self._length_function(s) < self._chunk_size:
_good_splits.append(s)
else:
# 如果文本块长度大于或等于 _chunk_size,则合并 _good_splits 中的文本块,并清空 _good_splits。
if _good_splits:
merged_text = self._merge_splits(_good_splits, _separator)
final_chunks.extend(merged_text)
_good_splits = []
# 如果没有剩余的分隔符 (new_separators 为空),则直接将当前文本块添加到 final_chunks。
if not new_separators:
final_chunks.append(s)
# 如果还有剩余的分隔符,递归调用 _split_text 方法继续分割当前文本块。
else:
other_info = self._split_text(s, new_separators)
final_chunks.extend(other_info)
if _good_splits:
merged_text = self._merge_splits(_good_splits, _separator)
final_chunks.extend(merged_text)
return final_chunks
所以,所谓递归分割的方法就是:按照分隔符的优先级进行分割。当使用高优先级分割出来的某个文档块超过指定长度,会对这个超长的文本块递归调用低优先级的分隔符继续进行分割。
other_info = self._split_text(s, new_separators)
使用正则表达式根据指定的分隔符分割文本
_split_text_with_regex
函数通过正则表达式根据指定的分隔符分割文本,并根据 keep_separator
参数决定是否保留分隔符。该函数在处理过程中会合并文本块和分隔符(如果需要保留分隔符),并过滤掉空字符串,最终返回一个非空的分割结果列表。
def _split_text_with_regex(
text: str, separator: str, keep_separator: bool
) -> List[str]:
# Now that we have the separator, split the text
if separator:
if keep_separator:
# The parentheses in the pattern keep the delimiters in the result.
_splits = re.split(f"({separator})", text)
# 通过列表推导式 [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)] 将文本块和分隔符合并。
splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
# 如果 _splits 的长度为偶数,说明最后一个元素是单独的文本块,需要添加到 splits 中。
if len(_splits) % 2 == 0:
splits += _splits[-1:]
# 将第一个元素(原始文本的第一个部分)添加到 splits 的开头。
splits = [_splits[0]] + splits
else:
splits = re.split(separator, text)
else:
splits = list(text)
# 过滤掉分割结果中的空字符串,返回最终的分割结果列表。
return [s for s in splits if s != ""]
分割后的文本块合并成中等大小的块
分隔符可能将文本切分得非常稀碎,_merge_splits
会将小的块合并到接近预设的_chunk_size
长度。
就是一个简单的指针方法,current_doc用于存储当前指针攒的文档数量,total用于计算已经攒的文档的长度。长度够了就合并。
def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
# We now want to combine these smaller pieces into medium size
# chunks to send to the LLM.
separator_len = self._length_function(separator)
docs = []
current_doc: List[str] = []
total = 0
for d in splits:
_len = self._length_function(d)
if (
total + _len + (separator_len if len(current_doc) > 0 else 0)
> self._chunk_size
):
if total > self._chunk_size:
logger.warning(
f"Created a chunk of size {total}, "
f"which is longer than the specified {self._chunk_size}"
)
if len(current_doc) > 0:
doc = self._join_docs(current_doc, separator)
if doc is not None:
docs.append(doc)
# 从current_doc开始删除文档块,保留小于_chunk_overlap的长度即可
# Keep on popping if:
# - we have a larger chunk than in the chunk overlap
# - or if we still have any chunks and the length is long
while total > self._chunk_overlap or (
total + _len + (separator_len if len(current_doc) > 0 else 0)
> self._chunk_size
and total > 0
):
total -= self._length_function(current_doc[0]) + (
separator_len if len(current_doc) > 1 else 0
)
current_doc = current_doc[1:]
current_doc.append(d)
total += _len + (separator_len if len(current_doc) > 1 else 0)
doc = self._join_docs(current_doc, separator)
if doc is not None:
docs.append(doc)
return docs