utils.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import re
  15. def _is_sentence_dot(text, i):
  16. """
  17. Check if the given character is a sentence ending punctuation.
  18. """
  19. # if the character is not a period, return False
  20. if text[i] != ".":
  21. return False
  22. # previous character
  23. prev = text[i - 1] if i > 0 else ""
  24. # next character
  25. next = text[i + 1] if i + 1 < len(text) else ""
  26. # previous is digit or letter, then not sentence ending punctuation
  27. if prev.isdigit() or prev.isalpha():
  28. return False
  29. # next is digit or letter, then not sentence ending punctuation
  30. if next.isdigit() or next.isalpha():
  31. return False
  32. # next is a punctuation, then sentence ending punctuation
  33. if next in ("", " ", "\t", "\n", '"', "'", "”", "’", ")", "】", "」", "》"):
  34. return True
  35. return False
  36. def _find_split_pos(text, chunk_size):
  37. """
  38. Find the position to split the text into two chunks.
  39. Args:
  40. text (str): The original text to be split.
  41. chunk_size (int): The maximum size of each chunk.
  42. Returns:
  43. int: The index where the text should be split.
  44. """
  45. center = len(text) // 2
  46. split_chars = ["\n", "。", ";", ";", "!", "!", "?", "?"]
  47. # Search forward
  48. for i in range(center, len(text)):
  49. if text[i] in split_chars:
  50. # Check for whitespace around the split character
  51. j = i + 1
  52. while j < len(text) and text[j] in " \t\n":
  53. j += 1
  54. if j < len(text) and len(text[:j]) <= chunk_size:
  55. return i, j
  56. elif text[i] == "." and _is_sentence_dot(text, i):
  57. j = i + 1
  58. while j < len(text) and text[j] in " \t\n":
  59. j += 1
  60. if j < len(text) and len(text[:j]) <= chunk_size:
  61. return i, j
  62. # Search backward
  63. for i in range(center, 0, -1):
  64. if text[i] in split_chars:
  65. j = i + 1
  66. while j < len(text) and text[j] in " \t\n":
  67. j += 1
  68. if len(text[:j]) <= chunk_size:
  69. return i, j
  70. elif text[i] == "." and _is_sentence_dot(text, i):
  71. j = i + 1
  72. while j < len(text) and text[j] in " \t\n":
  73. j += 1
  74. if len(text[:j]) <= chunk_size:
  75. return i, j
  76. # If no suitable position is found, split directly
  77. return min(chunk_size, len(text)), min(chunk_size, len(text))
  78. def split_text_recursive(text, chunk_size, translate_func):
  79. """
  80. Split the text recursively and translate each chunk.
  81. Args:
  82. text (str): The original text to be split.
  83. chunk_size (int): The maximum size of each chunk.
  84. translate_func (callable): A function that translates a single chunk of text.
  85. results (list): A list to store the translated chunks.
  86. Returns:
  87. None
  88. """
  89. text = text.strip()
  90. if len(text) <= chunk_size:
  91. return translate_func(text)
  92. else:
  93. split_pos, end_whitespace = _find_split_pos(text, chunk_size)
  94. left = text[:split_pos]
  95. right = text[end_whitespace:]
  96. whitespace = text[split_pos:end_whitespace]
  97. if left:
  98. left_text = split_text_recursive(left, chunk_size, translate_func)
  99. if right:
  100. right_text = split_text_recursive(right, chunk_size, translate_func)
  101. return left_text + whitespace + right_text
  102. def translate_code_block(code_block, chunk_size, translate_func, results):
  103. """
  104. Translate a code block and append the result to the results list.
  105. Args:
  106. code_block (str): The code block to be translated.
  107. chunk_size (int): The maximum size of each chunk.
  108. translate_func (callable): A function that translates a single chunk of text.
  109. results (list): A list to store the translated chunks.
  110. Returns:
  111. None
  112. """
  113. lines = code_block.strip().split("\n")
  114. if lines[0].startswith("```") or lines[0].startswith("~~~"):
  115. header = lines[0]
  116. footer = (
  117. lines[-1]
  118. if (lines[-1].startswith("```") or lines[-1].startswith("~~~"))
  119. else ""
  120. )
  121. code_content = "\n".join(lines[1:-1]) if footer else "\n".join(lines[1:])
  122. else:
  123. header = ""
  124. footer = ""
  125. code_content = code_block
  126. translated_code_lines = split_text_recursive(
  127. code_content, chunk_size, translate_func
  128. )
  129. # drop ``` or ~~~
  130. filtered_code_lines = [
  131. line
  132. for line in translated_code_lines.split("\n")
  133. if not (line.strip().startswith("```") or line.strip().startswith("~~~"))
  134. ]
  135. translated_code = "\n".join(filtered_code_lines)
  136. result = f"{header}\n{translated_code}\n{footer}" if header else translated_code
  137. results.append(result)
  138. def translate_html_block(html_block, chunk_size, translate_func, results):
  139. """
  140. Translate a HTML block and append the result to the results list.
  141. Args:
  142. html_block (str): The HTML block to be translated.
  143. chunk_size (int): The maximum size of each chunk.
  144. translate_func (callable): A function that translates a single chunk of text.
  145. results (list): A list to store the translated chunks.
  146. Returns:
  147. None
  148. """
  149. from bs4 import BeautifulSoup
  150. # if this is a short and simple tag, just translate it
  151. if (
  152. html_block.count("<") < 5
  153. and html_block.count(">") < 5
  154. and html_block.count("<") == html_block.count(">")
  155. and len(html_block) < chunk_size
  156. ):
  157. translated = translate_func(html_block)
  158. results.append(translated)
  159. return
  160. soup = BeautifulSoup(html_block, "html.parser")
  161. # collect text nodes
  162. text_nodes = []
  163. for node in soup.find_all(string=True, recursive=True):
  164. text = node.strip()
  165. if text:
  166. text_nodes.append(node)
  167. idx = 0
  168. total = len(text_nodes)
  169. while idx < total:
  170. batch_nodes = []
  171. li_texts = []
  172. current_length = len("<ol></ol>")
  173. while idx < total:
  174. node_text = text_nodes[idx].strip()
  175. if len(node_text) > chunk_size:
  176. # if node_text is too long, split it
  177. translated_text = split_text_recursive(
  178. node_text, chunk_size, translate_func
  179. )
  180. # concatenate translated lines with \n
  181. text_nodes[idx].replace_with(translated_text)
  182. idx += 1
  183. continue
  184. li_str = f"<li>{node_text}</li>"
  185. if current_length + len(li_str) > chunk_size:
  186. break
  187. batch_nodes.append(text_nodes[idx])
  188. li_texts.append(li_str)
  189. current_length += len(li_str)
  190. idx += 1
  191. if not batch_nodes:
  192. # if all individual nodes are longer than chunk_size, translate it alone
  193. node_text = text_nodes[idx - 1].strip()
  194. li_str = f"<li>{node_text}</li>"
  195. batch_nodes = [text_nodes[idx - 1]]
  196. li_texts = [li_str]
  197. if batch_nodes:
  198. batch_text = "<ol>" + "".join(li_texts) + "</ol>"
  199. translated = translate_func(batch_text)
  200. trans_soup = BeautifulSoup(translated, "html.parser")
  201. translated_lis = trans_soup.find_all("li")
  202. for orig_node, li_tag in zip(batch_nodes, translated_lis):
  203. orig_node.replace_with(li_tag.decode_contents())
  204. results.append(str(soup))
  205. def split_original_texts(text):
  206. """
  207. Split the original text into chunks.
  208. """
  209. from bs4 import BeautifulSoup
  210. # find all html blocks and replace them with placeholders
  211. soup = BeautifulSoup(text, "html.parser")
  212. html_blocks = []
  213. html_placeholders = []
  214. placeholder_fmt = "<<HTML_BLOCK_{}>>"
  215. text_after_placeholder = ""
  216. index = 0
  217. for elem in soup.contents:
  218. if hasattr(elem, "name") and elem.name is not None:
  219. html_str = str(elem)
  220. placeholder = placeholder_fmt.format(index)
  221. html_blocks.append(html_str)
  222. html_placeholders.append(placeholder)
  223. text_after_placeholder += placeholder
  224. index += 1
  225. else:
  226. text_after_placeholder += str(elem)
  227. # split text into paragraphs
  228. splited_block = []
  229. splited_block = split_and_append_text(splited_block, text_after_placeholder)
  230. # replace placeholders with html blocks
  231. current_index = 0
  232. for idx, block in enumerate(splited_block):
  233. _, content = block
  234. while (
  235. current_index < len(html_placeholders)
  236. and html_placeholders[current_index] in content
  237. ):
  238. content = content.replace(
  239. html_placeholders[current_index], html_blocks[current_index]
  240. )
  241. current_index += 1
  242. splited_block[idx] = ("html", content)
  243. return splited_block
  244. def split_and_append_text(result, text_content):
  245. """
  246. Split the text and append the result to the result list.
  247. Args:
  248. result (list): The current result list.
  249. text_content (str): The text content to be processed.
  250. Returns:
  251. list: The updated result list after processing the text content.
  252. """
  253. if text_content.strip():
  254. # match all code block interval
  255. code_pattern = re.compile(r"(```.*?\n.*?```|~~~.*?\n.*?~~~)", re.DOTALL)
  256. last_pos = 0
  257. for m in code_pattern.finditer(text_content):
  258. # process text before code block
  259. if m.start() > last_pos:
  260. non_code = text_content[last_pos : m.start()]
  261. paragraphs = re.split(r"\n{2,}", non_code)
  262. for p in paragraphs:
  263. if p.strip():
  264. result.append(("text", p.strip()))
  265. # process code block
  266. result.append(("code", m.group()))
  267. last_pos = m.end()
  268. # process remaining text
  269. if last_pos < len(text_content):
  270. non_code = text_content[last_pos:]
  271. paragraphs = re.split(r"\n{2,}", non_code)
  272. for p in paragraphs:
  273. if p.strip():
  274. result.append(("text", p.strip()))
  275. return result