I. TextSplitter

TextSplitter inherits from BaseDocumentTransformer. It is an abstract class, so it cannot be instantiated directly.

Its core (internal) attributes are:

_chunk_size: the target size of each chunk
_chunk_overlap: the size of the overlap between adjacent chunks
_length_function: the function used to measure size; it can be a token-counting function, or something plainer such as the built-in len()
_keep_separator: Boolean; whether to keep the separator in the chunks after splitting
_add_start_index: Boolean; whether to store, in each returned document's metadata, the index of the chunk's first character within the original document
_strip_whitespace: Boolean; whether to strip leading and trailing whitespace from each chunk after splitting
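To see why _length_function is pluggable, note that the same text measures very differently by characters than by tokens. A quick illustration (whitespace splitting here is only a crude stand-in for a real tokenizer such as tiktoken):

```python
text = "TextSplitter measures chunk size with a pluggable length function"

char_count = len(text)          # character-based sizing
word_count = len(text.split())  # crude whitespace "token" count, a stand-in
                                # for a real tokenizer

print(char_count, word_count)   # 65 9
```

A chunk_size of 100 therefore means something very different depending on which measuring function the splitter was constructed with.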
Core methods:

split_text(self, text: str) -> list[str]
The splitting method itself. It is abstract and must be implemented by each concrete subclass according to its splitting algorithm.

create_documents(self, texts: list[str], metadatas: list[dict]) -> list[Document]
Takes a list of texts and optional per-text metadata, splits each text with split_text, and returns the chunks wrapped as Document objects: doc.page_content holds the chunk text, metadata holds the metadata, and the start index is recorded automatically when _add_start_index is set.
split_documents(self, documents: Iterable[Document]) -> list[Document]
Splits the given documents and returns the resulting list. Internally it just calls create_documents with each document's content and metadata and combines the results.
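Neither method depends on a particular splitting algorithm, so their behavior can be sketched with a stub Document class and any split_text callable. This is a simplified mimic, not the library code: the real create_documents tracks offsets more carefully (accounting for overlap between repeated chunks).

```python
from dataclasses import dataclass, field

@dataclass
class Document:
    """Stand-in for langchain's Document class, for illustration only."""
    page_content: str
    metadata: dict = field(default_factory=dict)

def create_documents(split_text, texts, metadatas=None, add_start_index=True):
    """Simplified mimic of TextSplitter.create_documents."""
    metadatas = metadatas or [{} for _ in texts]
    documents = []
    for text, meta in zip(texts, metadatas):
        index = 0
        for chunk in split_text(text):
            metadata = dict(meta)  # copy so each Document owns its metadata
            if add_start_index:
                # Naive offset search; the real method also accounts for
                # overlap so repeated chunks resolve to distinct offsets.
                index = text.find(chunk, index)
                metadata["start_index"] = index
            documents.append(Document(page_content=chunk, metadata=metadata))
    return documents

docs = create_documents(str.split, ["hello brave world"], [{"source": "demo"}])
print([(d.page_content, d.metadata["start_index"]) for d in docs])
# [('hello', 0), ('brave', 6), ('world', 12)]
```

split_documents is then just this loop fed with each document's page_content and metadata.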
-------- Internal methods --------

_join_docs(self, docs: list[str], separator: str) -> str
Note that docs here is a list of strings: the method joins them into one long string with the given separator, for use by _merge_splits below.
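Its behavior is easy to restate in isolation. A standalone sketch, faithful to the method's two quirks (the _strip_whitespace handling and returning None for an empty result):

```python
def join_docs(docs: list[str], separator: str, strip_whitespace: bool = True):
    """Standalone sketch of TextSplitter._join_docs."""
    text = separator.join(docs)
    if strip_whitespace:
        text = text.strip()
    # Returning None for an empty result lets _merge_splits skip blank chunks.
    return text if text else None

print(join_docs(["foo", "bar"], " "))  # foo bar
print(join_docs(["   "], ""))          # None
```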
_merge_splits(self, splits: Iterable[str], separator: str) -> list[str]
Merges pieces that were split too finely into chunks closer to self._chunk_size, while ensuring adjacent chunks share roughly self._chunk_overlap worth of overlapping content.
```python
def _merge_splits(self, splits: Iterable[str], separator: str) -> list[str]:
    # We now want to combine these smaller pieces into medium size
    # chunks to send to the LLM.
    separator_len = self._length_function(separator)

    docs = []
    current_doc: list[str] = []
    total = 0
    for d in splits:
        len_ = self._length_function(d)
        # By default d just gets appended to current_doc below; only once
        # the if fires does a merged chunk get pushed into docs.
        if (
            total + len_ + (separator_len if len(current_doc) > 0 else 0)
            > self._chunk_size
        ):
            if total > self._chunk_size:
                logger.warning(
                    "Created a chunk of size %d, which is longer than the "
                    "specified %d",
                    total,
                    self._chunk_size,
                )
            if len(current_doc) > 0:
                doc = self._join_docs(current_doc, separator)
                if doc is not None:
                    docs.append(doc)
                # Keep on popping if:
                # - we have a larger chunk than in the chunk overlap
                # - or if we still have any chunks and the length is long
                while total > self._chunk_overlap or (
                    total + len_ + (separator_len if len(current_doc) > 0 else 0)
                    > self._chunk_size
                    and total > 0
                ):
                    total -= self._length_function(current_doc[0]) + (
                        separator_len if len(current_doc) > 1 else 0
                    )
                    current_doc = current_doc[1:]
        current_doc.append(d)
        total += len_ + (separator_len if len(current_doc) > 1 else 0)
    doc = self._join_docs(current_doc, separator)
    if doc is not None:
        docs.append(doc)
    return docs
```
The core of this method: whenever current_doc fills up to chunk_size, it first joins the strings in current_doc and appends the result to docs. Then, instead of clearing current_doc outright, it removes text units from the head of current_doc one by one until the remaining length is no more than _chunk_overlap. What is left in current_doc becomes the start of the next chunk, i.e. the overlapping content shared by the two chunks.
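To see this sliding-overlap behavior concretely, here is a standalone sketch of the same algorithm, with _join_docs inlined as a plain join, the warning and None-checks dropped, and character counting as the length function:

```python
def merge_splits(splits, separator, chunk_size, chunk_overlap, length=len):
    """Standalone sketch of TextSplitter._merge_splits (no logging/stripping)."""
    separator_len = length(separator)
    docs, current_doc, total = [], [], 0
    for d in splits:
        len_ = length(d)
        if total + len_ + (separator_len if current_doc else 0) > chunk_size:
            if current_doc:
                docs.append(separator.join(current_doc))
            # Slide the window: drop units from the front until at most
            # ~chunk_overlap worth of text remains to seed the next chunk.
            while total > chunk_overlap or (
                total + len_ + (separator_len if current_doc else 0) > chunk_size
                and total > 0
            ):
                total -= length(current_doc[0]) + (
                    separator_len if len(current_doc) > 1 else 0
                )
                current_doc = current_doc[1:]
        current_doc.append(d)
        total += len_ + (separator_len if len(current_doc) > 1 else 0)
    docs.append(separator.join(current_doc))
    return docs

chunks = merge_splits(["aa", "bb", "cc", "dd"], "", chunk_size=4, chunk_overlap=2)
print(chunks)  # ['aabb', 'bbcc', 'ccdd'] — each chunk shares 2 chars with the next
```

With chunk_size=4 and chunk_overlap=2, every flush keeps the last two characters in current_doc, so each output chunk starts with the tail of the previous one.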
II. CharacterTextSplitter

This class inherits from TextSplitter above and adds a separator attribute and an is_separator_regex attribute (whether the separator is a regular expression). It implements the parent's abstract split_text method.

Its split_text method calls the helper _split_text_with_regex() to split the incoming text. First, the code:
```python
# Inside the CharacterTextSplitter class
def split_text(self, text: str) -> list[str]:
    """Split into chunks without re-inserting lookaround separators."""
    # 1. Determine split pattern: raw regex or escaped literal
    sep_pattern = (
        self._separator if self._is_separator_regex else re.escape(self._separator)
    )

    # 2. Initial split (keep separator if requested)
    splits = _split_text_with_regex(
        text, sep_pattern, keep_separator=self._keep_separator
    )

    # 3. Detect zero-width lookaround so we never re-insert it
    lookaround_prefixes = ("(?=", "(?<!", "(?<=", "(?!")
    is_lookaround = self._is_separator_regex and any(
        self._separator.startswith(p) for p in lookaround_prefixes
    )

    # 4. Decide merge separator:
    #    - if keep_separator or lookaround -> don't re-insert
    #    - else -> re-insert literal separator
    merge_sep = ""
    if not (self._keep_separator or is_lookaround):
        merge_sep = self._separator

    # 5. Merge adjacent splits and return
    return self._merge_splits(splits, merge_sep)


# Module-level helper
def _split_text_with_regex(
    text: str, separator: str, *, keep_separator: bool | Literal["start", "end"]
) -> list[str]:
    # Now that we have the separator, split the text
    if separator:
        if keep_separator:
            # The parentheses in the pattern keep the delimiters in the result.
            splits_ = re.split(f"({separator})", text)
            splits = (
                ([splits_[i] + splits_[i + 1] for i in range(0, len(splits_) - 1, 2)])
                if keep_separator == "end"
                else ([splits_[i] + splits_[i + 1] for i in range(1, len(splits_), 2)])
            )
            if len(splits_) % 2 == 0:
                splits += splits_[-1:]
            splits = (
                ([*splits, splits_[-1]])
                if keep_separator == "end"
                else ([splits_[0], *splits])
            )
        else:
            splits = re.split(separator, text)
    else:
        splits = list(text)
    return [s for s in splits if s]
```
If you set aside separator preservation, the method is actually simple: it uses re.split to break the incoming text on the separator, then calls the parent's _merge_splits() to stitch the pieces back into appropriately sized chunks, returning a list[str].

1. Pre-split processing

If the separator passed in is a plain string, it must be escaped before calling re.split, so that any regex metacharacters it contains are treated as literal characters rather than pattern syntax.
```python
# 1. Determine split pattern: raw regex or escaped literal
sep_pattern = (
    self._separator if self._is_separator_regex else re.escape(self._separator)
)
```
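Without escaping, a metacharacter in a literal separator would change the meaning of the split entirely. For example, "." as a regex matches any character, so splitting on it unescaped shreds the whole string:

```python
import re

raw = re.split(".", "a.b")                 # "." matches every character
escaped = re.split(re.escape("."), "a.b")  # "\." matches only the literal dot

print(raw)      # ['', '', '', '']
print(escaped)  # ['a', 'b']
```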
```python
splits = (
    ([splits_[i] + splits_[i + 1] for i in range(0, len(splits_) - 1, 2)])
    if keep_separator == "end"
    else ([splits_[i] + splits_[i + 1] for i in range(1, len(splits_), 2)])
)
if len(splits_) % 2 == 0:
    splits += splits_[-1:]
splits = (
    ([*splits, splits_[-1]])
    if keep_separator == "end"
    else ([splits_[0], *splits])
)
```
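This recombination glues each captured separator either onto the piece that follows it (the default truthy / "start" path) or onto the piece that precedes it ("end"). Re-deriving both branches by hand on a small input:

```python
import re

splits_ = re.split("(,)", "a,b,c")
print(splits_)  # ['a', ',', 'b', ',', 'c'] — the capture group keeps delimiters

# "start": each separator attaches to the piece after it
start = [splits_[0]] + [splits_[i] + splits_[i + 1] for i in range(1, len(splits_), 2)]
# "end": each separator attaches to the piece before it
end = [splits_[i] + splits_[i + 1] for i in range(0, len(splits_) - 1, 2)] + [splits_[-1]]

print(start)  # ['a', ',b', ',c']
print(end)    # ['a,', 'b,', 'c']
```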
```python
merge_sep = ""
if not (self._keep_separator or is_lookaround):
    merge_sep = self._separator
```
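The lookaround special case matters because a zero-width pattern such as (?<=\.) consumes no text: the split removes nothing, so re-inserting the separator during merging would duplicate characters. A quick check with re.split, which supports zero-width patterns since Python 3.7:

```python
import re

# Split after every period without consuming it
pieces = re.split(r"(?<=\.)", "one. two. three")
print(pieces)  # ['one.', ' two.', ' three']

# The periods are still inside the pieces, so merging with "" loses nothing
assert "".join(pieces) == "one. two. three"
```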