
support glossary for translation

zhouchangda, 4 months ago
Parent commit: 93ac70c84b

+ 13 - 1
docs/pipeline_usage/tutorials/ocr_pipelines/PP-DocTranslation.en.md

@@ -1386,13 +1386,25 @@ After executing the above code, you will obtain the parsed results of the origin
 <td><code>str|None</code></td>
 <td>
 <ul>
-<li><b>str</b>: Example data in key-value pair format, which can include a terminology对照表 (glossary)</li>
+<li><b>str</b>: Example data in key-value pair format</li>
 <li><b>None</b>: Do not provide structured examples</li>
 </ul>
 </td>
 <td><code>None</code></td>
 </tr>
 <tr>
+<td><code>glossary</code></td>
+<td>Glossary of technical terms</td>
+<td><code>dict|None</code></td>
+<td>
+<ul>
+<li><b>dict</b>: Dictionary for glossary mapping</li>
+<li><b>None</b>: Use default configuration</li>
+</ul>
+</td>
+<td><code>None</code></td>
+</tr>
+<tr>
 <td><code>llm_request_interval</code></td>
 <td>Time interval in seconds for sending requests to the large language model. This parameter can be used to prevent overly frequent calls to the large language model.</td>
 <td><code>float</code></td>

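The `glossary` argument added above slots into the translate call that this parameter table documents. Below is a minimal usage sketch; the surrounding names (`create_pipeline`, `visual_predict`, `translate`, `target_language`, `save_to_markdown`) follow the tutorial's existing example and may differ slightly across versions, and the glossary contents are invented for illustration.

```python
from paddlex import create_pipeline

pipeline = create_pipeline(pipeline="PP-DocTranslation")

# Parse the source document first (as in the tutorial's existing example)
# to obtain the markdown info list that the translate call consumes.
ori_md_info_list = []
for res in pipeline.visual_predict(input="document_sample.pdf"):
    ori_md_info_list.append(res["layout_parsing_result"].markdown)

# New in this commit: a dict of preferred term translations. A value may also be
# a list of acceptable renderings, which the pipeline joins with "或" ("or").
glossary = {
    "pipeline": "产线",
    "inference": ["推理", "预测"],
}

tgt_md_info_list = pipeline.translate(
    ori_md_info_list=ori_md_info_list,
    target_language="zh",
    glossary=glossary,
    llm_request_interval=1.0,
)
for tgt_md_info in tgt_md_info_list:
    tgt_md_info.save_to_markdown("./output/")
```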
+ 12 - 0
docs/pipeline_usage/tutorials/ocr_pipelines/PP-DocTranslation.md

@@ -1448,6 +1448,18 @@ for tgt_md_info in tgt_md_info_list:
 <td><code>None</code></td>
 </tr>
 <tr>
+<td><code>glossary</code></td>
+<td>专业术语对照表</td>
+<td><code>dict|None</code></td>
+<td>
+<ul>
+<li><b>dict</b>:词表映射字典</li>
+<li><b>None</b>:使用默认配置</li>
+</ul>
+</td>
+<td><code>None</code></td>
+</tr>
+<tr>
 <td><code>llm_request_interval</code></td>
 <td>向大语言模型发送请求的时间间隔,单位为秒。该参数可用于防止过于频繁地调用大语言模型。</td>
 <td><code>float</code></td>

+ 1 - 1
paddlex/inference/pipelines/components/prompt_engineering/generate_translate_prompt.py

@@ -163,7 +163,7 @@ class GenerateTranslatePrompt(BaseGeneratePrompt):
             few_shot_demo_key_value_list = self.few_shot_demo_key_value_list
 
         if few_shot_demo_key_value_list:
-            few_shot_demo_key_value_list = f"这里是一些专业术语对照表,对照表中单词要参考对照表翻译:\n{few_shot_demo_key_value_list}\n"
+            few_shot_demo_key_value_list = f"\n这里是一些专业术语对照表,如果遇到对照表中单词要参考对照表翻译:\n{few_shot_demo_key_value_list}\n"
 
         prompt = f"""{task_description}{rules_str}{output_format}{few_shot_demo_text_content}{few_shot_demo_key_value_list}"""
 

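For reference, the reworded Chinese instruction reads roughly: "Here is a glossary of technical terms; if a word from the glossary appears, translate it according to the glossary." The change also prepends a newline so the glossary section is separated from the preceding demo text. A tiny sketch of the wrapped string, with sample glossary content only:

```python
# Reproduce the wrapping applied to few_shot_demo_key_value_list in the prompt.
few_shot_demo_key_value_list = "pipeline: 产线\ninference: 推理或预测\n"
glossary_section = f"\n这里是一些专业术语对照表,如果遇到对照表中单词要参考对照表翻译:\n{few_shot_demo_key_value_list}\n"
print(glossary_section)
```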
+ 20 - 2
paddlex/inference/pipelines/pp_doctranslation/pipeline.py

@@ -326,8 +326,8 @@ class PP_DocTranslation_Pipeline(BasePipeline):
                     chunk = ""  # Clear the chunk
 
                 if block_type == "text":
-                    split_text_recursive(
-                        block_content, chunk_size, translate_func, translation_results
+                    translation_results.append(
+                        split_text_recursive(block_content, chunk_size, translate_func)
                     )
                 elif block_type == "text_with_html" or block_type == "html":
                     translate_html_block(
@@ -350,6 +350,7 @@ class PP_DocTranslation_Pipeline(BasePipeline):
         rules_str: str = None,
         few_shot_demo_text_content: str = None,
         few_shot_demo_key_value_list: str = None,
+        glossary: Dict = None,
         llm_request_interval: float = 0.0,
         chat_bot_config: Dict = None,
         **kwargs,
@@ -366,6 +367,7 @@ class PP_DocTranslation_Pipeline(BasePipeline):
             rules_str (str, optional): Rules or guidelines for the translation model to follow. Defaults to None.
             few_shot_demo_text_content (str, optional): Demo text content for the translation model. Defaults to None.
             few_shot_demo_key_value_list (str, optional): Demo text key-value list for the translation model. Defaults to None.
+            glossary (Dict, optional): A dictionary containing terms and their corresponding definitions. Defaults to None.
             llm_request_interval (float, optional): The interval in seconds between each request to the LLM. Defaults to 0.0.
             chat_bot_config (Dict, optional): Configuration for the chat bot used in the translation process. Defaults to None.
             **kwargs: Additional keyword arguments passed to the translation model.
@@ -396,6 +398,22 @@ class PP_DocTranslation_Pipeline(BasePipeline):
         if not isinstance(llm_request_interval, float):
             llm_request_interval = float(llm_request_interval)
 
+        assert isinstance(glossary, dict) or glossary is None, "glossary must be a dict"
+
+        glossary_str = ""
+        if glossary is not None:
+            for k, v in glossary.items():
+                if isinstance(v, list):
+                    v = "或".join(v)
+                glossary_str += f"{k}: {v}\n"
+
+        if glossary_str != "":
+            if few_shot_demo_key_value_list is None:
+                few_shot_demo_key_value_list = glossary_str
+            else:
+                few_shot_demo_key_value_list += "\n"
+                few_shot_demo_key_value_list += glossary_str
+
         def translate_func(text):
             """
             Translate the given text using the configured translation model.

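A minimal reproduction of how the new `glossary` dict is flattened into the string appended to `few_shot_demo_key_value_list` (logic copied from the hunk above; the terms are invented):

```python
glossary = {"pipeline": "产线", "inference": ["推理", "预测"]}

glossary_str = ""
for k, v in glossary.items():
    if isinstance(v, list):
        v = "或".join(v)  # list values become "推理或预测" ("... or ...")
    glossary_str += f"{k}: {v}\n"

print(glossary_str)
# pipeline: 产线
# inference: 推理或预测
```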
+ 78 - 18
paddlex/inference/pipelines/pp_doctranslation/utils.py

@@ -15,6 +15,29 @@
 import re
 
 
+def _is_sentence_dot(text, i):
+    """
+    Check if the given character is a sentence ending punctuation.
+    """
+    # if the character is not a period, return False
+    if text[i] != ".":
+        return False
+    # previous character
+    prev = text[i - 1] if i > 0 else ""
+    # next character
+    next = text[i + 1] if i + 1 < len(text) else ""
+    # previous is digit or letter, then not sentence ending punctuation
+    if prev.isdigit() or prev.isalpha():
+        return False
+    # next is digit or letter, then not sentence ending punctuation
+    if next.isdigit() or next.isalpha():
+        return False
+    # next is a punctuation, then sentence ending punctuation
+    if next in ("", " ", "\t", "\n", '"', "'", "”", "’", ")", "】", "」", "》"):
+        return True
+    return False
+
+
 def _find_split_pos(text, chunk_size):
     """
     Find the position to split the text into two chunks.
@@ -27,21 +50,44 @@ def _find_split_pos(text, chunk_size):
         int: The index where the text should be split.
     """
     center = len(text) // 2
+    split_chars = ["\n", "。", ";", ";", "!", "!", "?", "?"]
+
     # Search forward
     for i in range(center, len(text)):
-        if text[i] in ["\n", ".", "。", ";", ";", "!", "!", "?", "?"]:
-            if i + 1 < len(text) and len(text[: i + 1]) <= chunk_size:
-                return i + 1
+        if text[i] in split_chars:
+            # Check for whitespace around the split character
+            j = i + 1
+            while j < len(text) and text[j] in " \t\n":
+                j += 1
+            if j < len(text) and len(text[:j]) <= chunk_size:
+                return i, j
+        elif text[i] == "." and _is_sentence_dot(text, i):
+            j = i + 1
+            while j < len(text) and text[j] in " \t\n":
+                j += 1
+            if j < len(text) and len(text[:j]) <= chunk_size:
+                return i, j
+
     # Search backward
     for i in range(center, 0, -1):
-        if text[i] in ["\n", ".", "。", ";", ";", "!", "!", "?", "?"]:
-            if len(text[: i + 1]) <= chunk_size:
-                return i + 1
+        if text[i] in split_chars:
+            j = i + 1
+            while j < len(text) and text[j] in " \t\n":
+                j += 1
+            if len(text[:j]) <= chunk_size:
+                return i, j
+        elif text[i] == "." and _is_sentence_dot(text, i):
+            j = i + 1
+            while j < len(text) and text[j] in " \t\n":
+                j += 1
+            if len(text[:j]) <= chunk_size:
+                return i, j
+
     # If no suitable position is found, split directly
-    return min(chunk_size, len(text))
+    return min(chunk_size, len(text)), min(chunk_size, len(text))
 
 
-def split_text_recursive(text, chunk_size, translate_func, results):
+def split_text_recursive(text, chunk_size, translate_func):
     """
     Split the text recursively and translate each chunk.
 
@@ -56,15 +102,19 @@ def split_text_recursive(text, chunk_size, translate_func, results):
     """
     text = text.strip()
     if len(text) <= chunk_size:
-        results.append(translate_func(text))
+        return translate_func(text)
     else:
-        split_pos = _find_split_pos(text, chunk_size)
-        left = text[:split_pos].strip()
-        right = text[split_pos:].strip()
+        split_pos, end_whitespace = _find_split_pos(text, chunk_size)
+        left = text[:split_pos]
+        right = text[end_whitespace:]
+        whitespace = text[split_pos:end_whitespace]
+
         if left:
-            split_text_recursive(left, chunk_size, translate_func, results)
+            left_text = split_text_recursive(left, chunk_size, translate_func)
         if right:
-            split_text_recursive(right, chunk_size, translate_func, results)
+            right_text = split_text_recursive(right, chunk_size, translate_func)
+
+        return left_text + whitespace + right_text
 
 
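A small sketch of the reworked splitting helpers (module path inferred from the file path above; the identity `translate_func` stands in for a real LLM call):

```python
from paddlex.inference.pipelines.pp_doctranslation.utils import (
    _is_sentence_dot,
    split_text_recursive,
)

# The "." heuristic is conservative: a dot touching a digit or letter is never a
# sentence boundary, while a dot after closing punctuation followed by a space is.
print(_is_sentence_dot("3.14", 1))             # False (decimal point)
print(_is_sentence_dot('said "ok". Done', 9))  # True  (dot after a closing quote)

# split_text_recursive now returns the translated text instead of appending to a
# results list, and re-inserts each split character plus its trailing whitespace
# between the translated chunks.
text = "PaddleX supports document translation.\n" * 40
translated = split_text_recursive(text, chunk_size=500, translate_func=lambda s: s)
assert translated == text.strip()
```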
 def translate_code_block(code_block, chunk_size, translate_func, results):
@@ -94,15 +144,14 @@ def translate_code_block(code_block, chunk_size, translate_func, results):
         footer = ""
         code_content = code_block
 
-    translated_code_lines = []
-    split_text_recursive(
-        code_content, chunk_size, translate_func, translated_code_lines
+    translated_code_lines = split_text_recursive(
+        code_content, chunk_size, translate_func
     )
 
     # drop ``` or ~~~
     filtered_code_lines = [
         line
-        for line in translated_code_lines
+        for line in translated_code_lines.split("\n")
         if not (line.strip().startswith("```") or line.strip().startswith("~~~"))
     ]
     translated_code = "\n".join(filtered_code_lines)
@@ -126,6 +175,17 @@ def translate_html_block(html_block, chunk_size, translate_func, results):
     """
     from bs4 import BeautifulSoup
 
+    # if this is a short and simple tag, just translate it
+    if (
+        html_block.count("<") < 5
+        and html_block.count(">") < 5
+        and html_block.count("<") == html_block.count(">")
+        and len(html_block) < chunk_size
+    ):
+        translated = translate_func(html_block)
+        results.append(translated)
+        return
+
     soup = BeautifulSoup(html_block, "html.parser")
 
     # collect text nodes
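The last hunk adds a fast path to `translate_html_block`: a short, simple fragment is sent to the LLM in one call instead of being decomposed into text nodes with BeautifulSoup. A small sketch of the gating condition (copied from the hunk) and which inputs take the shortcut:

```python
def takes_shortcut(html_block, chunk_size=5000):
    # Same condition as the new fast path: fewer than five "<" and ">" characters,
    # balanced brackets, and the whole fragment shorter than chunk_size.
    return (
        html_block.count("<") < 5
        and html_block.count(">") < 5
        and html_block.count("<") == html_block.count(">")
        and len(html_block) < chunk_size
    )

print(takes_shortcut("<b>产线配置</b>"))                      # True: one simple tag pair
print(takes_shortcut("<table><tr><td>a</td></tr></table>"))   # False: six tags
```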