format_utils.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. import re
  2. import itertools
  3. import html
  4. from typing import Any, Dict, List
  5. from pydantic import (
  6. BaseModel,
  7. computed_field,
  8. model_validator,
  9. )
  10. class TableCell(BaseModel):
  11. """TableCell."""
  12. row_span: int = 1
  13. col_span: int = 1
  14. start_row_offset_idx: int
  15. end_row_offset_idx: int
  16. start_col_offset_idx: int
  17. end_col_offset_idx: int
  18. text: str
  19. column_header: bool = False
  20. row_header: bool = False
  21. row_section: bool = False
  22. @model_validator(mode="before")
  23. @classmethod
  24. def from_dict_format(cls, data: Any) -> Any:
  25. """from_dict_format."""
  26. if isinstance(data, Dict):
  27. # Check if this is a native BoundingBox or a bbox from docling-ibm-models
  28. if (
  29. # "bbox" not in data
  30. # or data["bbox"] is None
  31. # or isinstance(data["bbox"], BoundingBox)
  32. "text"
  33. in data
  34. ):
  35. return data
  36. text = data["bbox"].get("token", "")
  37. if not len(text):
  38. text_cells = data.pop("text_cell_bboxes", None)
  39. if text_cells:
  40. for el in text_cells:
  41. text += el["token"] + " "
  42. text = text.strip()
  43. data["text"] = text
  44. return data
  45. class TableData(BaseModel): # TBD
  46. """BaseTableData."""
  47. table_cells: List[TableCell] = []
  48. num_rows: int = 0
  49. num_cols: int = 0
  50. @computed_field # type: ignore
  51. @property
  52. def grid(
  53. self,
  54. ) -> List[List[TableCell]]:
  55. """grid."""
  56. # Initialise empty table data grid (only empty cells)
  57. table_data = [
  58. [
  59. TableCell(
  60. text="",
  61. start_row_offset_idx=i,
  62. end_row_offset_idx=i + 1,
  63. start_col_offset_idx=j,
  64. end_col_offset_idx=j + 1,
  65. )
  66. for j in range(self.num_cols)
  67. ]
  68. for i in range(self.num_rows)
  69. ]
  70. # Overwrite cells in table data for which there is actual cell content.
  71. for cell in self.table_cells:
  72. for i in range(
  73. min(cell.start_row_offset_idx, self.num_rows),
  74. min(cell.end_row_offset_idx, self.num_rows),
  75. ):
  76. for j in range(
  77. min(cell.start_col_offset_idx, self.num_cols),
  78. min(cell.end_col_offset_idx, self.num_cols),
  79. ):
  80. table_data[i][j] = cell
  81. return table_data
  82. """
  83. OTSL
  84. """
  85. OTSL_NL = "<nl>"
  86. OTSL_FCEL = "<fcel>"
  87. OTSL_ECEL = "<ecel>"
  88. OTSL_LCEL = "<lcel>"
  89. OTSL_UCEL = "<ucel>"
  90. OTSL_XCEL = "<xcel>"
  91. def otsl_extract_tokens_and_text(s: str):
  92. # Pattern to match anything enclosed by < >
  93. # (including the angle brackets themselves)
  94. # pattern = r"(<[^>]+>)"
  95. pattern = r"(" + r"|".join([OTSL_NL, OTSL_FCEL, OTSL_ECEL, OTSL_LCEL, OTSL_UCEL, OTSL_XCEL]) + r")"
  96. # Find all tokens (e.g. "<otsl>", "<loc_140>", etc.)
  97. tokens = re.findall(pattern, s)
  98. # Remove any tokens that start with "<loc_"
  99. tokens = [token for token in tokens]
  100. # Split the string by those tokens to get the in-between text
  101. text_parts = re.split(pattern, s)
  102. text_parts = [token for token in text_parts]
  103. # Remove any empty or purely whitespace strings from text_parts
  104. text_parts = [part for part in text_parts if part.strip()]
  105. return tokens, text_parts
  106. def otsl_parse_texts(texts, tokens):
  107. split_word = OTSL_NL
  108. split_row_tokens = [
  109. list(y)
  110. for x, y in itertools.groupby(tokens, lambda z: z == split_word)
  111. if not x
  112. ]
  113. table_cells = []
  114. r_idx = 0
  115. c_idx = 0
  116. def count_right(tokens, c_idx, r_idx, which_tokens):
  117. span = 0
  118. c_idx_iter = c_idx
  119. while tokens[r_idx][c_idx_iter] in which_tokens:
  120. c_idx_iter += 1
  121. span += 1
  122. if c_idx_iter >= len(tokens[r_idx]):
  123. return span
  124. return span
  125. def count_down(tokens, c_idx, r_idx, which_tokens):
  126. span = 0
  127. r_idx_iter = r_idx
  128. while tokens[r_idx_iter][c_idx] in which_tokens:
  129. r_idx_iter += 1
  130. span += 1
  131. if r_idx_iter >= len(tokens):
  132. return span
  133. return span
  134. for i, text in enumerate(texts):
  135. cell_text = ""
  136. if text in [
  137. OTSL_FCEL,
  138. OTSL_ECEL,
  139. ]:
  140. row_span = 1
  141. col_span = 1
  142. right_offset = 1
  143. if text != OTSL_ECEL:
  144. cell_text = texts[i + 1]
  145. right_offset = 2
  146. # Check next element(s) for lcel / ucel / xcel,
  147. # set properly row_span, col_span
  148. next_right_cell = ""
  149. if i + right_offset < len(texts):
  150. next_right_cell = texts[i + right_offset]
  151. next_bottom_cell = ""
  152. if r_idx + 1 < len(split_row_tokens):
  153. if c_idx < len(split_row_tokens[r_idx + 1]):
  154. next_bottom_cell = split_row_tokens[r_idx + 1][c_idx]
  155. if next_right_cell in [
  156. OTSL_LCEL,
  157. OTSL_XCEL,
  158. ]:
  159. # we have horisontal spanning cell or 2d spanning cell
  160. col_span += count_right(
  161. split_row_tokens,
  162. c_idx + 1,
  163. r_idx,
  164. [OTSL_LCEL, OTSL_XCEL],
  165. )
  166. if next_bottom_cell in [
  167. OTSL_UCEL,
  168. OTSL_XCEL,
  169. ]:
  170. # we have a vertical spanning cell or 2d spanning cell
  171. row_span += count_down(
  172. split_row_tokens,
  173. c_idx,
  174. r_idx + 1,
  175. [OTSL_UCEL, OTSL_XCEL],
  176. )
  177. table_cells.append(
  178. TableCell(
  179. text=cell_text.strip(),
  180. row_span=row_span,
  181. col_span=col_span,
  182. start_row_offset_idx=r_idx,
  183. end_row_offset_idx=r_idx + row_span,
  184. start_col_offset_idx=c_idx,
  185. end_col_offset_idx=c_idx + col_span,
  186. )
  187. )
  188. if text in [
  189. OTSL_FCEL,
  190. OTSL_ECEL,
  191. OTSL_LCEL,
  192. OTSL_UCEL,
  193. OTSL_XCEL,
  194. ]:
  195. c_idx += 1
  196. if text == OTSL_NL:
  197. r_idx += 1
  198. c_idx = 0
  199. return table_cells, split_row_tokens
  200. def export_to_html(table_data: TableData):
  201. nrows = table_data.num_rows
  202. ncols = table_data.num_cols
  203. text = ""
  204. if len(table_data.table_cells) == 0:
  205. return ""
  206. body = ""
  207. for i in range(nrows):
  208. body += "<tr>"
  209. for j in range(ncols):
  210. cell: TableCell = table_data.grid[i][j]
  211. rowspan, rowstart = (
  212. cell.row_span,
  213. cell.start_row_offset_idx,
  214. )
  215. colspan, colstart = (
  216. cell.col_span,
  217. cell.start_col_offset_idx,
  218. )
  219. if rowstart != i:
  220. continue
  221. if colstart != j:
  222. continue
  223. content = html.escape(cell.text.strip())
  224. celltag = "td"
  225. if cell.column_header:
  226. celltag = "th"
  227. opening_tag = f"{celltag}"
  228. if rowspan > 1:
  229. opening_tag += f' rowspan="{rowspan}"'
  230. if colspan > 1:
  231. opening_tag += f' colspan="{colspan}"'
  232. body += f"<{opening_tag}>{content}</{celltag}>"
  233. body += "</tr>"
  234. # dir = get_text_direction(text)
  235. body = f"<table>{body}</table>"
  236. return body
  237. def convert_otsl_to_html(otsl_content: str):
  238. tokens, mixed_texts = otsl_extract_tokens_and_text(otsl_content)
  239. table_cells, split_row_tokens = otsl_parse_texts(mixed_texts, tokens)
  240. table_data = TableData(
  241. num_rows=len(split_row_tokens),
  242. num_cols=(
  243. max(len(row) for row in split_row_tokens) if split_row_tokens else 0
  244. ),
  245. table_cells=table_cells,
  246. )
  247. return export_to_html(table_data)