utils.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import re
  15. def _is_sentence_dot(text, i):
  16. """
  17. Check if the given character is a sentence ending punctuation.
  18. """
  19. # if the character is not a period, return False
  20. if text[i] != ".":
  21. return False
  22. # previous character
  23. prev = text[i - 1] if i > 0 else ""
  24. # next character
  25. next = text[i + 1] if i + 1 < len(text) else ""
  26. # previous is digit or letter, then not sentence ending punctuation
  27. if prev.isdigit() or prev.isalpha():
  28. return False
  29. # next is digit or letter, then not sentence ending punctuation
  30. if next.isdigit() or next.isalpha():
  31. return False
  32. # next is a punctuation, then sentence ending punctuation
  33. if next in ("", " ", "\t", "\n", '"', "'", "”", "’", ")", "】", "」", "》"):
  34. return True
  35. return False
  36. def _find_split_pos(text, chunk_size):
  37. """
  38. Find the position to split the text into two chunks.
  39. Args:
  40. text (str): The original text to be split.
  41. chunk_size (int): The maximum size of each chunk.
  42. Returns:
  43. int: The index where the text should be split.
  44. """
  45. center = len(text) // 2
  46. split_chars = ["\n", "。", ";", ";", "!", "!", "?", "?"]
  47. # Search forward
  48. for i in range(center, len(text)):
  49. if text[i] in split_chars:
  50. # Check for whitespace around the split character
  51. j = i + 1
  52. while j < len(text) and text[j] in " \t\n":
  53. j += 1
  54. if j < len(text) and len(text[:j]) <= chunk_size:
  55. return i, j
  56. elif text[i] == "." and _is_sentence_dot(text, i):
  57. j = i + 1
  58. while j < len(text) and text[j] in " \t\n":
  59. j += 1
  60. if j < len(text) and len(text[:j]) <= chunk_size:
  61. return i, j
  62. # Search backward
  63. for i in range(center, 0, -1):
  64. if text[i] in split_chars:
  65. j = i + 1
  66. while j < len(text) and text[j] in " \t\n":
  67. j += 1
  68. if len(text[:j]) <= chunk_size:
  69. return i, j
  70. elif text[i] == "." and _is_sentence_dot(text, i):
  71. j = i + 1
  72. while j < len(text) and text[j] in " \t\n":
  73. j += 1
  74. if len(text[:j]) <= chunk_size:
  75. return i, j
  76. # If no suitable position is found, split directly
  77. return min(chunk_size, len(text)), min(chunk_size, len(text))
  78. def split_text_recursive(text, chunk_size, translate_func):
  79. """
  80. Split the text recursively and translate each chunk.
  81. Args:
  82. text (str): The original text to be split.
  83. chunk_size (int): The maximum size of each chunk.
  84. translate_func (callable): A function that translates a single chunk of text.
  85. results (list): A list to store the translated chunks.
  86. Returns:
  87. None
  88. """
  89. text = text.strip()
  90. if len(text) <= chunk_size:
  91. return translate_func(text)
  92. else:
  93. split_pos, end_whitespace = _find_split_pos(text, chunk_size)
  94. left = text[:split_pos]
  95. right = text[end_whitespace:]
  96. whitespace = text[split_pos:end_whitespace]
  97. if left:
  98. left_text = split_text_recursive(left, chunk_size, translate_func)
  99. if right:
  100. right_text = split_text_recursive(right, chunk_size, translate_func)
  101. return left_text + whitespace + right_text
  102. def translate_code_block(code_block, chunk_size, translate_func, results):
  103. """
  104. Translate a code block and append the result to the results list.
  105. Args:
  106. code_block (str): The code block to be translated.
  107. chunk_size (int): The maximum size of each chunk.
  108. translate_func (callable): A function that translates a single chunk of text.
  109. results (list): A list to store the translated chunks.
  110. Returns:
  111. None
  112. """
  113. lines = code_block.strip().split("\n")
  114. if lines[0].startswith("```") or lines[0].startswith("~~~"):
  115. header = lines[0]
  116. footer = (
  117. lines[-1]
  118. if (lines[-1].startswith("```") or lines[-1].startswith("~~~"))
  119. else ""
  120. )
  121. code_content = "\n".join(lines[1:-1]) if footer else "\n".join(lines[1:])
  122. else:
  123. header = ""
  124. footer = ""
  125. code_content = code_block
  126. translated_code_lines = split_text_recursive(
  127. code_content, chunk_size, translate_func
  128. )
  129. # drop ``` or ~~~
  130. filtered_code_lines = [
  131. line
  132. for line in translated_code_lines.split("\n")
  133. if not (line.strip().startswith("```") or line.strip().startswith("~~~"))
  134. ]
  135. translated_code = "\n".join(filtered_code_lines)
  136. result = f"{header}\n{translated_code}\n{footer}" if header else translated_code
  137. results.append(result)
  138. def translate_html_block(html_block, chunk_size, translate_func, results):
  139. """
  140. Translate a HTML block and append the result to the results list.
  141. Args:
  142. html_block (str): The HTML block to be translated.
  143. chunk_size (int): The maximum size of each chunk.
  144. translate_func (callable): A function that translates a single chunk of text.
  145. results (list): A list to store the translated chunks.
  146. Returns:
  147. None
  148. """
  149. from bs4 import BeautifulSoup
  150. # if this is a short and simple tag, just translate it
  151. if (
  152. html_block.count("<") < 5
  153. and html_block.count(">") < 5
  154. and html_block.count("<") == html_block.count(">")
  155. and len(html_block) < chunk_size
  156. ):
  157. translated = translate_func(html_block)
  158. results.append(translated)
  159. return
  160. soup = BeautifulSoup(html_block, "html.parser")
  161. # collect text nodes
  162. text_nodes = []
  163. for node in soup.find_all(string=True, recursive=True):
  164. text = node.strip()
  165. if text:
  166. text_nodes.append(node)
  167. idx = 0
  168. total = len(text_nodes)
  169. while idx < total:
  170. batch_nodes = []
  171. li_texts = []
  172. current_length = len("<ol></ol>")
  173. while idx < total:
  174. node_text = text_nodes[idx].strip()
  175. if len(node_text) > chunk_size:
  176. # if node_text is too long, split it
  177. translated_lines = []
  178. split_text_recursive(
  179. node_text, chunk_size, translate_func, translated_lines
  180. )
  181. # concatenate translated lines with \n
  182. text_nodes[idx].replace_with("\n".join(translated_lines))
  183. idx += 1
  184. continue
  185. li_str = f"<li>{node_text}</li>"
  186. if current_length + len(li_str) > chunk_size:
  187. break
  188. batch_nodes.append(text_nodes[idx])
  189. li_texts.append(li_str)
  190. current_length += len(li_str)
  191. idx += 1
  192. if not batch_nodes:
  193. # if all individual nodes are longer than chunk_size, translate it alone
  194. node_text = text_nodes[idx - 1].strip()
  195. li_str = f"<li>{node_text}</li>"
  196. batch_nodes = [text_nodes[idx - 1]]
  197. li_texts = [li_str]
  198. if batch_nodes:
  199. batch_text = "<ol>" + "".join(li_texts) + "</ol>"
  200. translated = translate_func(batch_text)
  201. trans_soup = BeautifulSoup(translated, "html.parser")
  202. translated_lis = trans_soup.find_all("li")
  203. for orig_node, li_tag in zip(batch_nodes, translated_lis):
  204. orig_node.replace_with(li_tag.decode_contents())
  205. results.append(str(soup))
  206. def split_original_texts(text):
  207. """
  208. Split the original text into chunks.
  209. """
  210. from bs4 import BeautifulSoup
  211. # find all html blocks and replace them with placeholders
  212. soup = BeautifulSoup(text, "html.parser")
  213. html_blocks = []
  214. html_placeholders = []
  215. placeholder_fmt = "<<HTML_BLOCK_{}>>"
  216. text_after_placeholder = ""
  217. index = 0
  218. for elem in soup.contents:
  219. if hasattr(elem, "name") and elem.name is not None:
  220. html_str = str(elem)
  221. placeholder = placeholder_fmt.format(index)
  222. html_blocks.append(html_str)
  223. html_placeholders.append(placeholder)
  224. text_after_placeholder += placeholder
  225. index += 1
  226. else:
  227. text_after_placeholder += str(elem)
  228. # split text into paragraphs
  229. splited_block = []
  230. splited_block = split_and_append_text(splited_block, text_after_placeholder)
  231. # replace placeholders with html blocks
  232. current_index = 0
  233. for idx, block in enumerate(splited_block):
  234. _, content = block
  235. while (
  236. current_index < len(html_placeholders)
  237. and html_placeholders[current_index] in content
  238. ):
  239. content = content.replace(
  240. html_placeholders[current_index], html_blocks[current_index]
  241. )
  242. current_index += 1
  243. splited_block[idx] = ("html", content)
  244. return splited_block
  245. def split_and_append_text(result, text_content):
  246. """
  247. Split the text and append the result to the result list.
  248. Args:
  249. result (list): The current result list.
  250. text_content (str): The text content to be processed.
  251. Returns:
  252. list: The updated result list after processing the text content.
  253. """
  254. if text_content.strip():
  255. # match all code block interval
  256. code_pattern = re.compile(r"(```.*?\n.*?```|~~~.*?\n.*?~~~)", re.DOTALL)
  257. last_pos = 0
  258. for m in code_pattern.finditer(text_content):
  259. # process text before code block
  260. if m.start() > last_pos:
  261. non_code = text_content[last_pos : m.start()]
  262. paragraphs = re.split(r"\n{2,}", non_code)
  263. for p in paragraphs:
  264. if p.strip():
  265. result.append(("text", p.strip()))
  266. # process code block
  267. result.append(("code", m.group()))
  268. last_pos = m.end()
  269. # process remaining text
  270. if last_pos < len(text_content):
  271. non_code = text_content[last_pos:]
  272. paragraphs = re.split(r"\n{2,}", non_code)
  273. for p in paragraphs:
  274. if p.strip():
  275. result.append(("text", p.strip()))
  276. return result