
split block by '\n\n' & update doc (#4286)

changdazhou, 4 months ago
parent commit b102f238dd

+ 1 - 1
api_examples/pipelines/test_pp_doctranslation.py

@@ -56,7 +56,7 @@ else:
 tgt_md_info_list = pipeline.translate(
     ori_md_info_list=ori_md_info_list,
     target_language="en",
-    chunk_size=5000,
+    chunk_size=3000,
     chat_bot_config=chat_bot_config,
 )
 for tgt_md_info in tgt_md_info_list:

+ 9 - 7
docs/pipeline_usage/tutorials/ocr_pipelines/PP-DocTranslation.md

@@ -7,6 +7,8 @@ comments: true
 ## 1. Introduction to the PP-DocTranslation pipeline
 The general document translation pipeline (PP-DocTranslation) is PaddlePaddle's intelligent document translation solution. It combines advanced general layout parsing technology with large language model (LLM) capabilities to provide efficient, intelligent document translation. The solution accurately identifies and extracts the various elements of a document, including text blocks, headings, paragraphs, images, tables, and other complex layout structures, and builds high-quality multilingual translation on top of that. PP-DocTranslation supports mutual translation between many mainstream languages and excels at documents with complex layouts and strong contextual dependencies, aiming for accurate, natural, fluent, and professional output. The pipeline also offers flexible serving-based deployment and can be invoked from multiple programming languages on a variety of hardware. Moreover, it supports secondary development: you can train and fine-tune on your own dataset, and the resulting models can be integrated seamlessly.
 
+<img src="https://raw.githubusercontent.com/cuicheng01/PaddleX_doc_images/main/images/pipelines/doc_translation/pp_doctranslation.png">
+
 
 <b>The general document translation pipeline uses the general layout parsing v3 sub-pipeline and therefore provides all of its features. For a full description of the general layout parsing v3 pipeline and its usage details, see the [general layout parsing v3 pipeline documentation](./PP-StructureV3.md)</b>.
 
@@ -1441,6 +1443,13 @@ for tgt_md_info in tgt_md_info_list:
 <td><code>None</code></td>
 </tr>
 <tr>
+<td><code>llm_request_interval</code></td>
+<td>The interval, in seconds, between requests sent to the large language model. This parameter can be used to avoid calling the LLM too frequently.</td>
+<td><code>float</code></td>
+<td>A float greater than or equal to 0</td>
+<td><code>0</code></td>
+</tr>
+<tr>
 <td><code>chat_bot_config</code></td>
 <td>Large language model configuration</td>
 <td><code>dict|None</code></td>
@@ -1452,13 +1461,6 @@ for tgt_md_info in tgt_md_info_list:
 </td>
 <td><code>None</code></td>
 </tr>
-<tr>
-<td><code>llm_request_interval</code></td>
-<td>The interval, in seconds, between requests sent to the large language model. This parameter can be used to avoid calling the LLM too frequently.</td>
-<td><code>float</code></td>
-<td>A float greater than or equal to 0</td>
-<td><code>0</code></td>
-</tr>
 </tbody>
 </table>
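
For reference, a minimal sketch of how the `llm_request_interval` parameter documented above is passed to `translate()`. The `chat_bot_config` values below are placeholders (the model name, endpoint, and key are assumptions, not real settings):

```python
from paddlex import create_pipeline

pipeline = create_pipeline(pipeline="PP-DocTranslation")

# Placeholder LLM settings; substitute your own model name, endpoint, and key.
chat_bot_config = {
    "module_name": "chat_bot",
    "model_name": "ernie-3.5-8k",
    "base_url": "https://example.com/v1",
    "api_type": "openai",
    "api_key": "YOUR_API_KEY",
}

ori_md_info_list = []  # in a real run, produced by pipeline.visual_predict(...)

tgt_md_info_list = pipeline.translate(
    ori_md_info_list=ori_md_info_list,
    target_language="en",
    chunk_size=3000,
    llm_request_interval=1.0,  # pause 1 second between LLM requests
    chat_bot_config=chat_bot_config,
)
```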
 

+ 8 - 0
paddlex/inference/pipelines/components/prompt_engineering/generate_translate_prompt.py

@@ -154,9 +154,17 @@ class GenerateTranslatePrompt(BaseGeneratePrompt):
         if few_shot_demo_text_content is None:
             few_shot_demo_text_content = self.few_shot_demo_text_content
 
+        if few_shot_demo_text_content:
+            few_shot_demo_text_content = (
+                f"这里是一些示例:\n{few_shot_demo_text_content}\n"
+            )
+
         if few_shot_demo_key_value_list is None:
             few_shot_demo_key_value_list = self.few_shot_demo_key_value_list
 
+        if few_shot_demo_key_value_list:
+            few_shot_demo_key_value_list = f"这里是一些专业术语对照表,对照表中单词要参考对照表翻译:\n{few_shot_demo_key_value_list}\n"
+
         prompt = f"""{task_description}{rules_str}{output_format}{few_shot_demo_text_content}{few_shot_demo_key_value_list}"""
 
         language_name = language_map.get(language, language)
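
The two guards added above prepend Chinese instruction headers to the prompt: the first string means "Here are some examples:" and the second means "Here is a glossary of technical terms; glossary words should be translated according to it:". A standalone sketch of the same wrapping (the function name is hypothetical, not the actual method):

```python
def wrap_few_shot_sections(demo_text=None, glossary=None):
    """Hypothetical helper mirroring the guards added above."""
    sections = []
    if demo_text:
        # "Here are some examples:"
        sections.append(f"这里是一些示例:\n{demo_text}\n")
    if glossary:
        # "Here is a glossary; translate glossary terms according to it:"
        sections.append(
            f"这里是一些专业术语对照表,对照表中单词要参考对照表翻译:\n{glossary}\n"
        )
    return "".join(sections)

print(wrap_few_shot_sections(demo_text="machine learning -> 机器学习"))
```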

+ 29 - 5
paddlex/inference/pipelines/pp_doctranslation/pipeline.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import re
+from time import sleep
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -310,7 +311,7 @@ class PP_DocTranslation_Pipeline(BasePipeline):
                 translate_code_block(
                     block_content, chunk_size, translate_func, translation_results
                 )
-            elif len(block_content) < chunk_size:
+            elif len(block_content) < chunk_size and block_type == "text":
                 if len(chunk) + len(block_content) < chunk_size:
                     chunk += "\n\n" + block_content
                 else:
@@ -343,14 +344,14 @@ class PP_DocTranslation_Pipeline(BasePipeline):
         self,
         ori_md_info_list: List[Dict],
         target_language: str = "zh",
-        chunk_size: int = 5000,
+        chunk_size: int = 3000,
         task_description: str = None,
         output_format: str = None,
         rules_str: str = None,
         few_shot_demo_text_content: str = None,
         few_shot_demo_key_value_list: str = None,
-        chat_bot_config=None,
-        llm_request_interval: float = 0,
+        llm_request_interval: float = 0.0,
+        chat_bot_config: Dict = None,
         **kwargs,
     ):
         """
@@ -365,7 +366,8 @@ class PP_DocTranslation_Pipeline(BasePipeline):
             rules_str (str, optional): Rules or guidelines for the translation model to follow. Defaults to None.
             few_shot_demo_text_content (str, optional): Demo text content for the translation model. Defaults to None.
             few_shot_demo_key_value_list (str, optional): Demo text key-value list for the translation model. Defaults to None.
-            chat_bot_config (Any, optional): Configuration for the chat bot used in the translation process. Defaults to None.
+            llm_request_interval (float, optional): The interval in seconds between each request to the LLM. Defaults to 0.0.
+            chat_bot_config (Dict, optional): Configuration for the chat bot used in the translation process. Defaults to None.
             **kwargs: Additional keyword arguments passed to the translation model.
 
         Yields:
@@ -391,6 +393,9 @@ class PP_DocTranslation_Pipeline(BasePipeline):
             # for multi page pdf
             ori_md_info_list = [self.concatenate_markdown_pages(ori_md_info_list)]
 
+        if not isinstance(llm_request_interval, float):
+            llm_request_interval = float(llm_request_interval)
+
         def translate_func(text):
             """
             Translate the given text using the configured translation model.
@@ -401,6 +406,7 @@ class PP_DocTranslation_Pipeline(BasePipeline):
             Returns:
                 str: The translated text in the target language.
             """
+            sleep(llm_request_interval)
             prompt = self.translate_pe.generate_prompt(
                 original_text=text,
                 language=target_language,
@@ -415,6 +421,24 @@ class PP_DocTranslation_Pipeline(BasePipeline):
                 raise Exception("The call to the large model failed.")
             return translate
 
+        base_prompt_content = self.translate_pe.generate_prompt(
+            original_text="",
+            language=target_language,
+            task_description=task_description,
+            output_format=output_format,
+            rules_str=rules_str,
+            few_shot_demo_text_content=few_shot_demo_text_content,
+            few_shot_demo_key_value_list=few_shot_demo_key_value_list,
+        )
+        base_prompt_length = len(base_prompt_content)
+
+        if chunk_size > base_prompt_length:
+            chunk_size = chunk_size - base_prompt_length
+        else:
+            raise ValueError(
+                f"Chunk size should be greater than the base prompt length ({base_prompt_length}), but got {chunk_size}."
+            )
+
         for ori_md in ori_md_info_list:
 
             original_texts = ori_md["markdown_texts"]
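
A worked sketch of the chunk-size budgeting introduced above: the prompt is rendered once with empty text to measure its fixed overhead, and that overhead is subtracted from `chunk_size` so that prompt plus payload stays within the intended budget (the 450 figure below is a hypothetical prompt length):

```python
chunk_size = 3000
base_prompt_length = 450  # hypothetical length of the prompt rendered with empty text

if chunk_size > base_prompt_length:
    # Only the remainder is available for the actual document text.
    effective_chunk_size = chunk_size - base_prompt_length  # 2550
else:
    raise ValueError(
        f"Chunk size should be greater than the base prompt length "
        f"({base_prompt_length}), but got {chunk_size}."
    )
```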

+ 36 - 107
paddlex/inference/pipelines/pp_doctranslation/utils.py

@@ -181,118 +181,47 @@ def translate_html_block(html_block, chunk_size, translate_func, results):
 def split_original_texts(text):
     """
     Split the original text into chunks.
-
-    Args:
-        text (str): The original text to be split.
-
-    Returns:
-        list: A list of strings representing the chunks of the original text.
     """
     from bs4 import BeautifulSoup
 
+    # find all html blocks and replace them with placeholders
     soup = BeautifulSoup(text, "html.parser")
-    result = []
-    last_position = 0
-    contents = soup.contents
-    i = 0
-    while i < len(contents):
-        element = contents[i]
-        str_element = str(element)
-        if len(str_element) == 0:
-            i += 1
-            continue
-
-        # find element in original text
-        start = text.find(str_element, last_position)
-        if start != -1:
-            end = start + len(str_element)
-            element_str = text[start:end]
+    html_blocks = []
+    html_placeholders = []
+    placeholder_fmt = "<<HTML_BLOCK_{}>>"
+    text_after_placeholder = ""
+
+    index = 0
+    for elem in soup.contents:
+        if hasattr(elem, "name") and elem.name is not None:
+            html_str = str(elem)
+            placeholder = placeholder_fmt.format(index)
+            html_blocks.append(html_str)
+            html_placeholders.append(placeholder)
+            text_after_placeholder += placeholder
+            index += 1
         else:
-            # if element is not a tag, try to find it in original text
-            if hasattr(element, "name") and element.name is not None:
-                tag = element.name
-                pat = r"<{tag}.*?>.*?</{tag}>".format(tag=tag)
-                re_html = re.compile(pat, re.DOTALL)
-                match = re_html.search(text, last_position)
-                if match:
-                    start = match.start()
-                    end = match.end()
-                    element_str = text[start:end]
-                else:
-                    element_str = str_element
-                    start = -1
-                    end = -1
-            else:
-                element_str = str_element
-                start = -1
-                end = -1
-
-        true_start = True
-        if start > 0 and text[start - 1] != "\n":
-            true_start = False
-
-        # process previous text
-        if start != -1 and last_position < start:
-            text_content = text[last_position:start]
-            result = split_and_append_text(result, text_content)
-
-        if hasattr(element, "name") and element.name is not None:
-            if (
-                end < len(text)
-                and end >= 0
-                and (text[end] not in ["\n", " "] or element_str.endswith("\n"))
-            ):
-                next_block_pos = text.find("\n\n", end)
-                if next_block_pos == -1:
-                    mix_region_end = len(text)
-                else:
-                    mix_region_end = next_block_pos + 2
-
-                j = i + 1
-                while j < len(contents):
-                    next_element_str = str(contents[j])
-                    next_start = text.find(next_element_str, end)
-                    if next_start == -1 or next_start >= mix_region_end:
-                        break
-                    j += 1
-                if true_start:
-                    # merge text and html
-                    result.append(
-                        ("text_with_html", text[start:mix_region_end].rstrip("\n"))
-                    )
-                else:
-                    _, last_content = result[-1]
-                    result.pop()
-                    result.append(
-                        (
-                            "text_with_html",
-                            last_content + text[start:mix_region_end].rstrip("\n"),
-                        )
-                    )
-                last_position = mix_region_end
-                i = j
-            else:
-                # pure HTML block
-                if true_start:
-                    result.append(("html", element_str))
-                else:
-                    _, last_content = result[-1]
-                    result.pop()
-                    result.append(("html", last_content + element_str))
-                last_position = end
-                i += 1
-        else:
-            # normal text
-            result = split_and_append_text(result, element_str)
-            last_position = end if end != -1 else last_position + len(element_str)
-            i += 1
-
-    # process remaining text
-    if last_position < len(text):
-        text_content = text[last_position:]
-        result = split_and_append_text(result, text_content)
-
-    return result
+            text_after_placeholder += str(elem)
+
+    # split text into paragraphs
+    splited_block = []
+    splited_block = split_and_append_text(splited_block, text_after_placeholder)
+
+    # replace placeholders with html blocks
+    current_index = 0
+    for idx, block in enumerate(splited_block):
+        _, content = block
+        while (
+            current_index < len(html_placeholders)
+            and html_placeholders[current_index] in content
+        ):
+            content = content.replace(
+                html_placeholders[current_index], html_blocks[current_index]
+            )
+            current_index += 1
+            splited_block[idx] = ("html", content)
+
+    return splited_block
 
 
 def split_and_append_text(result, text_content):
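
To make the new placeholder strategy concrete, here is a self-contained sketch (the sample text and marker format are chosen for illustration): top-level HTML tags are swapped for markers, the remaining text is split on blank lines, and the markers are then swapped back:

```python
from bs4 import BeautifulSoup

text = "Intro paragraph.\n\n<table><tr><td>1</td></tr></table>\n\nClosing paragraph."

# Replace each top-level HTML tag with a placeholder marker.
soup = BeautifulSoup(text, "html.parser")
html_blocks, flattened = [], ""
for elem in soup.contents:
    if getattr(elem, "name", None) is not None:
        flattened += "<<HTML_BLOCK_{}>>".format(len(html_blocks))
        html_blocks.append(str(elem))
    else:
        flattened += str(elem)

# Split the flattened text on blank lines, then restore the HTML blocks.
chunks = [c for c in flattened.split("\n\n") if c.strip()]
for i, block in enumerate(html_blocks):
    chunks = [c.replace("<<HTML_BLOCK_{}>>".format(i), block) for c in chunks]

print(chunks)
# ['Intro paragraph.', '<table><tr><td>1</td></tr></table>', 'Closing paragraph.']
```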