
support split html to chunk (#4282)

changdazhou 4 months ago
parent
commit
0510f23820

+ 100 - 148
paddlex/inference/pipelines/pp_doctranslation/pipeline.py

@@ -24,10 +24,20 @@ from ...utils.hpi import HPIConfig
 from ...utils.pp_option import PaddlePredictorOption
 from ..base import BasePipeline
 from .result import MarkdownResult
+from .utils import (
+    split_original_texts,
+    split_text_recursive,
+    translate_code_block,
+    translate_html_block,
+)
 
 
 @pipeline_requires_extra("trans")
 class PP_DocTranslation_Pipeline(BasePipeline):
+    """
+    PP_DocTranslation_Pipeline
+    """
+
     entities = ["PP-DocTranslation"]
 
     def __init__(
@@ -246,125 +256,62 @@ class PP_DocTranslation_Pipeline(BasePipeline):
             markdown_info_list.append(MarkdownResult(markdown_info))
         return markdown_info_list
 
-    def split_markdown(self, md_text, chunk_size):
-        from bs4 import BeautifulSoup
-
-        if (
-            not isinstance(md_text, str)
-            or not isinstance(chunk_size, int)
-            or chunk_size <= 0
-        ):
-            raise ValueError("Invalid input parameters.")
-
-        chunks = []
-        current_chunk = []
-
-        # If the whole text is shorter than chunk_size, return it directly
-        if len(md_text) < chunk_size:
-            return [md_text]
-
-        # Paragraph split: two or more newlines mark a paragraph break
-        paragraphs = re.split(r"\n{2,}", md_text)
-
-        def split_table_to_chunks(table_html):
-            # Parse the table with BeautifulSoup
-            soup = BeautifulSoup(table_html, "html.parser")
-            table = soup.find("table")
-
-            if not table:
-                return [table_html]  # no table found, return the original content as-is
-
-            # Extract all <tr> rows
-            trs = table.find_all("tr")
-
-            # Accumulate row by row, keeping each chunk <= chunk_size without breaking a <tr>
-            table_chunks = []
-            current_rows = []
-            current_len = len("<table></table>")  # base length
-
-            for tr in trs:
-                tr_str = str(tr)
-                row_len = len(tr_str)
-                if current_rows and current_len + row_len > chunk_size:
-                    # Pack up the current chunk
-                    content = "<table>" + "".join(current_rows) + "</table>"
-                    table_chunks.append(content)
-                    current_rows = []  # reset the current row list
-                    current_len = len("<table></table>") + row_len
-
-                current_rows.append(tr_str)
-                current_len += row_len
-
-            if current_rows:
-                content = "<table>" + "".join(current_rows) + "</table>"
-                table_chunks.append(content)
-
-            return table_chunks
+    def chunk_translate(self, md_blocks, chunk_size, translate_func):
+        """
+        Splits the given markdown blocks into chunks of at most `chunk_size` characters and translates them with the
+        given translate function.
 
-        # Sentence split: English periods must be distinguished from decimal points
-        sentence_pattern = re.compile(
-            r"(?<=[。!?!?])|(?<=\.)\s+(?=[A-Z])|(?<=\.)\s*$"
-        )
+        Args:
+            md_blocks (list): A list of tuples representing each block of markdown content. Each tuple consists of a string
+                indicating the block type ('text', 'code', 'html', 'text_with_html') and the actual content of the block.
+            chunk_size (int): The maximum size of each chunk.
+            translate_func (callable): A callable that accepts a string argument and returns the translated version of that string.
 
-        for paragraph in paragraphs:
-            paragraph = paragraph.strip()
-            if not paragraph:
-                continue
-
-            # Use BeautifulSoup to check whether this is a complete table
-            soup = BeautifulSoup(paragraph, "html.parser")
-            table = soup.find("table")
-
-            if table:
-                table_html = str(table)
-                if len(table_html) <= chunk_size:
-                    if current_chunk:
-                        chunks.append("\n\n".join(current_chunk))
-                        current_chunk = []
-                    chunks.append(table_html)
+        Returns:
+            str: A string containing all the translated chunks joined with blank lines between them.
+        """
+        translation_results = []
+        chunk = ""
+        logging.info(f"Split the original text into {len(md_blocks)} blocks")
+        logging.info("Starting translation...")
+        for idx, block in enumerate(md_blocks):
+            block_type, block_content = block
+            if block_type == "code":
+                if chunk.strip():
+                    translation_results.append(translate_func(chunk.strip()))
+                    chunk = ""  # Clear the chunk
+                logging.info(f"Translating block {idx+1}/{len(md_blocks)}...")
+                translate_code_block(
+                    block_content, chunk_size, translate_func, translation_results
+                )
+            elif len(block_content) < chunk_size:
+                if len(chunk) + len(block_content) < chunk_size:
+                    chunk += "\n\n" + block_content
                 else:
-                    # Table too large, split it by rows
-                    if current_chunk:
-                        chunks.append("\n\n".join(current_chunk))
-                        current_chunk = []
-                    table_chunks = split_table_to_chunks(table_html)
-                    chunks.extend(table_chunks)
-                continue
-
-            # Plain text handling
-            if sum(len(s) for s in current_chunk) + len(paragraph) <= chunk_size:
-                current_chunk.append(paragraph)
-            elif len(paragraph) <= chunk_size:
-                if current_chunk:
-                    chunks.append("\n\n".join(current_chunk))
-                current_chunk = [paragraph]
+                    if chunk.strip():
+                        logging.info(f"Translating block {idx+1}/{len(md_blocks)}...")
+                        translation_results.append(translate_func(chunk.strip()))
+                    chunk = block_content
             else:
-                # Paragraph too long, split it by sentences
-                sentences = [
-                    s for s in sentence_pattern.split(paragraph) if s and s.strip()
-                ]
-                for sentence in sentences:
-                    sentence = sentence.strip()
-                    if not sentence:
-                        continue
-                    if len(sentence) > chunk_size:
-                        print(sentence)
-                        raise ValueError("A sentence exceeds the chunk size limit.")
-                    if sum(len(s) for s in current_chunk) + len(sentence) > chunk_size:
-                        if current_chunk:
-                            chunks.append("\n\n".join(current_chunk))
-                        current_chunk = [sentence]
-                    else:
-                        current_chunk.append(sentence)
-
-            if sum(len(s) for s in current_chunk) >= chunk_size:
-                chunks.append("\n\n".join(current_chunk))
-                current_chunk = []
-
-        if current_chunk:
-            chunks.append("\n\n".join(current_chunk))
-
-        return [c for c in chunks if c.strip()]
+                logging.info(f"Translating block {idx+1}/{len(md_blocks)}...")
+                if chunk.strip():
+                    translation_results.append(translate_func(chunk.strip()))
+                    chunk = ""  # Clear the chunk
+
+                if block_type == "text":
+                    split_text_recursive(
+                        block_content, chunk_size, translate_func, translation_results
+                    )
+                elif block_type in ("text_with_html", "html"):
+                    translate_html_block(
+                        block_content, chunk_size, translate_func, translation_results
+                    )
+                else:
+                    raise ValueError(f"Unknown block type: {block_type}")
+
+        if chunk.strip():
+            translation_results.append(translate_func(chunk.strip()))
+        return "\n\n".join(translation_results)
 
     def translate(
         self,
@@ -383,12 +330,19 @@ class PP_DocTranslation_Pipeline(BasePipeline):
         Translate the given original text into the specified target language using the configured translation model.
 
         Args:
-            original_text (str): The original text to be translated.
-            target_language (str): The desired target language code.
+            ori_md_info_list (List[Dict]): A list of dictionaries containing information about the original markdown text to be translated.
+            target_language (str, optional): The desired target language code. Defaults to "zh".
+            chunk_size (int, optional): The maximum number of characters allowed per chunk when splitting long texts. Defaults to 5000.
+            task_description (str, optional): A description of the task being performed by the translation model. Defaults to None.
+            output_format (str, optional): The desired output format of the translation result. Defaults to None.
+            rules_str (str, optional): Rules or guidelines for the translation model to follow. Defaults to None.
+            few_shot_demo_text_content (str, optional): Demo text content for the translation model. Defaults to None.
+            few_shot_demo_key_value_list (str, optional): Demo text key-value list for the translation model. Defaults to None.
+            chat_bot_config (Any, optional): Configuration for the chat bot used in the translation process. Defaults to None.
             **kwargs: Additional keyword arguments passed to the translation model.
 
-        Returns:
-            str: The translated text in the target language.
+        Yields:
+            MarkdownResult: A dictionary containing the translation result in the target language.
         """
         if self.chat_bot is None:
             logging.warning(
@@ -410,39 +364,37 @@ class PP_DocTranslation_Pipeline(BasePipeline):
             # for multi page pdf
             ori_md_info_list = [self.concatenate_markdown_pages(ori_md_info_list)]
 
+        def translate_func(text):
+            """
+            Translate the given text using the configured translation model.
+
+            Args:
+                text (str): The text to be translated.
+
+            Returns:
+                str: The translated text in the target language.
+            """
+            prompt = self.translate_pe.generate_prompt(
+                original_text=text,
+                language=target_language,
+                task_description=task_description,
+                output_format=output_format,
+                rules_str=rules_str,
+                few_shot_demo_text_content=few_shot_demo_text_content,
+                few_shot_demo_key_value_list=few_shot_demo_key_value_list,
+            )
+            translate = chat_bot.generate_chat_results(prompt=prompt).get("content", "")
+            if translate is None:
+                raise Exception("The call to the large model failed.")
+            return translate
+
         for ori_md in ori_md_info_list:
 
             original_texts = ori_md["markdown_texts"]
-            chunks = self.split_markdown(original_texts, chunk_size)
-
-            target_language_chunks = []
-
-            if len(chunks) > 1:
-                logging.info(
-                    f"Get the markdown text, it's length is {len(original_texts)}, will split it into {len(chunks)} parts."
-                )
-
-            logging.info(
-                "Starting to translate the markdown text, will take a while. please wait..."
+            md_blocks = split_original_texts(original_texts)
+            target_language_texts = self.chunk_translate(
+                md_blocks, chunk_size, translate_func
             )
-            for idx, chunk in enumerate(chunks):
-                logging.info(f"Translating the {idx+1}/{len(chunks)} part.")
-                prompt = self.translate_pe.generate_prompt(
-                    original_text=chunk,
-                    language=target_language,
-                    task_description=task_description,
-                    output_format=output_format,
-                    rules_str=rules_str,
-                    few_shot_demo_text_content=few_shot_demo_text_content,
-                    few_shot_demo_key_value_list=few_shot_demo_key_value_list,
-                )
-                target_language_chunk = chat_bot.generate_chat_results(
-                    prompt=prompt
-                ).get("content", "")
-
-                target_language_chunks.append(target_language_chunk)
-
-            target_language_texts = "\n\n".join(target_language_chunks)
 
             yield MarkdownResult(
                 {

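The new flow, end to end: `split_original_texts` turns the markdown into typed blocks, and `chunk_translate` batches and translates them. A minimal sketch (not part of the commit), assuming PaddleX at this commit is importable; the echo `translate_func` below is a stand-in for the chat-bot-backed translator:

```python
from paddlex.inference.pipelines.pp_doctranslation.utils import split_original_texts

md_text = "# Title\n\nSome paragraph.\n\n<table><tr><td>cell</td></tr></table>"

blocks = split_original_texts(md_text)
# -> [("text", "# Title"),
#     ("text", "Some paragraph."),
#     ("html", "<table><tr><td>cell</td></tr></table>")]

def translate_func(text):
    # stand-in for translate_pe.generate_prompt + chat_bot.generate_chat_results
    return f"[zh] {text}"

# On a pipeline instance, chunk_translate joins the translated chunks with
# blank lines and returns a single string:
#   pipeline.chunk_translate(blocks, chunk_size=5000, translate_func=translate_func)
```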
+ 331 - 0
paddlex/inference/pipelines/pp_doctranslation/utils.py

@@ -0,0 +1,331 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+
+
+def _find_split_pos(text, chunk_size):
+    """
+    Find the position to split the text into two chunks.
+
+    Args:
+        text (str): The original text to be split.
+        chunk_size (int): The maximum size of each chunk.
+
+    Returns:
+        int: The index where the text should be split.
+    """
+    center = len(text) // 2
+    # Search forward
+    for i in range(center, len(text)):
+        if text[i] in ["\n", ".", "。", ";", ";", "!", "!", "?", "?"]:
+            if i + 1 < len(text) and len(text[: i + 1]) <= chunk_size:
+                return i + 1
+    # Search backward
+    for i in range(center, 0, -1):
+        if text[i] in ["\n", ".", "。", ";", ";", "!", "!", "?", "?"]:
+            if len(text[: i + 1]) <= chunk_size:
+                return i + 1
+    # If no suitable position is found, split directly
+    return min(chunk_size, len(text))
+
+
+def split_text_recursive(text, chunk_size, translate_func, results):
+    """
+    Split the text recursively and translate each chunk.
+
+    Args:
+        text (str): The original text to be split.
+        chunk_size (int): The maximum size of each chunk.
+        translate_func (callable): A function that translates a single chunk of text.
+        results (list): A list to store the translated chunks.
+
+    Returns:
+        None
+    """
+    text = text.strip()
+    if len(text) <= chunk_size:
+        results.append(translate_func(text))
+    else:
+        split_pos = _find_split_pos(text, chunk_size)
+        left = text[:split_pos].strip()
+        right = text[split_pos:].strip()
+        if left:
+            split_text_recursive(left, chunk_size, translate_func, results)
+        if right:
+            split_text_recursive(right, chunk_size, translate_func, results)
+
+
+def translate_code_block(code_block, chunk_size, translate_func, results):
+    """
+    Translate a code block and append the result to the results list.
+
+    Args:
+        code_block (str): The code block to be translated.
+        chunk_size (int): The maximum size of each chunk.
+        translate_func (callable): A function that translates a single chunk of text.
+        results (list): A list to store the translated chunks.
+
+    Returns:
+        None
+    """
+    lines = code_block.strip().split("\n")
+    if lines[0].startswith("```") or lines[0].startswith("~~~"):
+        header = lines[0]
+        footer = (
+            lines[-1]
+            if (lines[-1].startswith("```") or lines[-1].startswith("~~~"))
+            else ""
+        )
+        code_content = "\n".join(lines[1:-1]) if footer else "\n".join(lines[1:])
+    else:
+        header = ""
+        footer = ""
+        code_content = code_block
+
+    translated_code_lines = []
+    split_text_recursive(
+        code_content, chunk_size, translate_func, translated_code_lines
+    )
+
+    # drop ``` or ~~~
+    filtered_code_lines = [
+        line
+        for line in translated_code_lines
+        if not (line.strip().startswith("```") or line.strip().startswith("~~~"))
+    ]
+    translated_code = "\n".join(filtered_code_lines)
+
+    result = f"{header}\n{translated_code}\n{footer}" if header else translated_code
+    results.append(result)
+
+
+def translate_html_block(html_block, chunk_size, translate_func, results):
+    """
+    Translate an HTML block and append the result to the results list.
+
+    Args:
+        html_block (str): The HTML block to be translated.
+        chunk_size (int): The maximum size of each chunk.
+        translate_func (callable): A function that translates a single chunk of text.
+        results (list): A list to store the translated chunks.
+
+    Returns:
+        None
+    """
+    from bs4 import BeautifulSoup
+
+    soup = BeautifulSoup(html_block, "html.parser")
+
+    # collect text nodes
+    text_nodes = []
+    for node in soup.find_all(string=True, recursive=True):
+        text = node.strip()
+        if text:
+            text_nodes.append(node)
+
+    idx = 0
+    total = len(text_nodes)
+    while idx < total:
+        batch_nodes = []
+        li_texts = []
+        current_length = len("<ol></ol>")
+        while idx < total:
+            node_text = text_nodes[idx].strip()
+            if len(node_text) > chunk_size:
+                # if node_text is too long, split it
+                translated_lines = []
+                split_text_recursive(
+                    node_text, chunk_size, translate_func, translated_lines
+                )
+                # concatenate translated lines with \n
+                text_nodes[idx].replace_with("\n".join(translated_lines))
+                idx += 1
+                continue
+            li_str = f"<li>{node_text}</li>"
+            if current_length + len(li_str) > chunk_size:
+                break
+            batch_nodes.append(text_nodes[idx])
+            li_texts.append(li_str)
+            current_length += len(li_str)
+            idx += 1
+        if not batch_nodes and idx < total:
+            # the <li> wrapper alone pushed this node over chunk_size;
+            # translate the node on its own and advance past it
+            node_text = text_nodes[idx].strip()
+            li_str = f"<li>{node_text}</li>"
+            batch_nodes = [text_nodes[idx]]
+            li_texts = [li_str]
+            idx += 1
+
+        if batch_nodes:
+            batch_text = "<ol>" + "".join(li_texts) + "</ol>"
+            translated = translate_func(batch_text)
+            trans_soup = BeautifulSoup(translated, "html.parser")
+            translated_lis = trans_soup.find_all("li")
+            for orig_node, li_tag in zip(batch_nodes, translated_lis):
+                orig_node.replace_with(li_tag.decode_contents())
+
+    results.append(str(soup))
+
+
+def split_original_texts(text):
+    """
+    Split the original text into typed blocks.
+
+    Args:
+        text (str): The original text to be split.
+
+    Returns:
+        list: A list of (block_type, content) tuples, where block_type is one of 'text', 'code', 'html', or 'text_with_html'.
+    """
+    from bs4 import BeautifulSoup
+
+    soup = BeautifulSoup(text, "html.parser")
+    result = []
+    last_position = 0
+    contents = soup.contents
+    i = 0
+    while i < len(contents):
+        element = contents[i]
+        str_element = str(element)
+        if len(str_element) == 0:
+            i += 1
+            continue
+
+        # find element in original text
+        start = text.find(str_element, last_position)
+        if start != -1:
+            end = start + len(str_element)
+            element_str = text[start:end]
+        else:
+            # exact match failed; if the element is a tag, try a regex match in the original text
+            if hasattr(element, "name") and element.name is not None:
+                tag = element.name
+                pat = r"<{tag}.*?>.*?</{tag}>".format(tag=tag)
+                re_html = re.compile(pat, re.DOTALL)
+                match = re_html.search(text, last_position)
+                if match:
+                    start = match.start()
+                    end = match.end()
+                    element_str = text[start:end]
+                else:
+                    element_str = str_element
+                    start = -1
+                    end = -1
+            else:
+                element_str = str_element
+                start = -1
+                end = -1
+
+        true_start = True
+        if start > 0 and text[start - 1] != "\n":
+            true_start = False
+
+        # process previous text
+        if start != -1 and last_position < start:
+            text_content = text[last_position:start]
+            result = split_and_append_text(result, text_content)
+
+        if hasattr(element, "name") and element.name is not None:
+            if (
+                end < len(text)
+                and end >= 0
+                and (text[end] not in ["\n", " "] or element_str.endswith("\n"))
+            ):
+                next_block_pos = text.find("\n\n", end)
+                if next_block_pos == -1:
+                    mix_region_end = len(text)
+                else:
+                    mix_region_end = next_block_pos + 2
+
+                j = i + 1
+                while j < len(contents):
+                    next_element_str = str(contents[j])
+                    next_start = text.find(next_element_str, end)
+                    if next_start == -1 or next_start >= mix_region_end:
+                        break
+                    j += 1
+                if true_start:
+                    # merge text and html
+                    result.append(
+                        ("text_with_html", text[start:mix_region_end].rstrip("\n"))
+                    )
+                else:
+                    _, last_content = result[-1]
+                    result.pop()
+                    result.append(
+                        (
+                            "text_with_html",
+                            last_content + text[start:mix_region_end].rstrip("\n"),
+                        )
+                    )
+                last_position = mix_region_end
+                i = j
+            else:
+                # pure HTML block
+                if true_start:
+                    result.append(("html", element_str))
+                else:
+                    _, last_content = result[-1]
+                    result.pop()
+                    result.append(("html", last_content + element_str))
+                last_position = end
+                i += 1
+        else:
+            # normal text
+            result = split_and_append_text(result, element_str)
+            last_position = end if end != -1 else last_position + len(element_str)
+            i += 1
+
+    # process remaining text
+    if last_position < len(text):
+        text_content = text[last_position:]
+        result = split_and_append_text(result, text_content)
+
+    return result
+
+
+def split_and_append_text(result, text_content):
+    """
+    Split the text and append the result to the result list.
+
+    Args:
+        result (list): The current result list.
+        text_content (str): The text content to be processed.
+
+    Returns:
+        list: The updated result list after processing the text content.
+    """
+    if text_content.strip():
+        # match all fenced code block intervals
+        code_pattern = re.compile(r"(```.*?\n.*?```|~~~.*?\n.*?~~~)", re.DOTALL)
+        last_pos = 0
+        for m in code_pattern.finditer(text_content):
+            # process text before code block
+            if m.start() > last_pos:
+                non_code = text_content[last_pos : m.start()]
+                paragraphs = re.split(r"\n{2,}", non_code)
+                for p in paragraphs:
+                    if p.strip():
+                        result.append(("text", p.strip()))
+            # process code block
+            result.append(("code", m.group()))
+            last_pos = m.end()
+        # process remaining text
+        if last_pos < len(text_content):
+            non_code = text_content[last_pos:]
+            paragraphs = re.split(r"\n{2,}", non_code)
+            for p in paragraphs:
+                if p.strip():
+                    result.append(("text", p.strip()))
+    return result