| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319 |
- import re
- import itertools
- import html
- from typing import Any, Dict, List
- from pydantic import (
- BaseModel,
- computed_field,
- model_validator,
- )
class TableCell(BaseModel):
    """A single table cell: its text, its span, and the grid area it covers."""

    # Number of grid rows / columns the cell spans.
    row_span: int = 1
    col_span: int = 1
    # Grid area covered by the cell; TableData.grid treats the end offsets
    # as exclusive (it iterates range(start, end)).
    start_row_offset_idx: int
    end_row_offset_idx: int
    start_col_offset_idx: int
    end_col_offset_idx: int
    text: str
    # Role flags; export_to_html renders column headers as <th>.
    column_header: bool = False
    row_header: bool = False
    row_section: bool = False

    @model_validator(mode="before")
    @classmethod
    def from_dict_format(cls, data: Any) -> Any:
        """Derive ``text`` for dict inputs that carry a ``bbox`` instead.

        Input that is not a dict, already has a ``"text"`` key, or has no
        usable ``"bbox"`` entry is returned unchanged so that regular field
        validation reports anything that is missing.  Otherwise the text is
        taken from ``bbox["token"]``; when that is empty, the ``"token"``
        values of the ``"text_cell_bboxes"`` entries are joined with single
        spaces instead.

        Args:
            data: Raw input handed to the model before field validation.

        Returns:
            ``data`` itself, or a shallow copy with ``"text"`` filled in
            (and ``"text_cell_bboxes"`` removed once it has been consumed).
        """
        if not isinstance(data, dict) or "text" in data:
            return data

        bbox = data.get("bbox")
        if bbox is None:
            # Nothing to derive the text from; let field validation complain
            # about the missing "text" instead of raising a bare KeyError.
            return data

        # Work on a shallow copy so the caller's dict is not modified.
        data = dict(data)
        text = bbox.get("token", "")
        if not text:
            text_cells = data.pop("text_cell_bboxes", None)
            if text_cells:
                text = " ".join(el["token"] for el in text_cells).strip()
        data["text"] = text
        return data
class TableData(BaseModel):  # TBD
    """Table content: a flat list of cells plus the table dimensions."""

    table_cells: List[TableCell] = []
    num_rows: int = 0
    num_cols: int = 0

    @computed_field  # type: ignore
    @property
    def grid(self) -> List[List[TableCell]]:
        """Return a ``num_rows`` x ``num_cols`` matrix of cells.

        Every position starts out as an empty 1x1 placeholder cell.  Each
        cell in ``table_cells`` is then written into every position of the
        area it covers, so a spanning cell appears several times.  Areas
        reaching past the table dimensions are clipped.
        """
        n_rows = self.num_rows
        n_cols = self.num_cols

        # Start from a matrix that holds only empty placeholder cells.
        matrix: List[List[TableCell]] = []
        for r in range(n_rows):
            placeholder_row = []
            for c in range(n_cols):
                placeholder_row.append(
                    TableCell(
                        text="",
                        start_row_offset_idx=r,
                        end_row_offset_idx=r + 1,
                        start_col_offset_idx=c,
                        end_col_offset_idx=c + 1,
                    )
                )
            matrix.append(placeholder_row)

        # Stamp every real cell over the positions it covers.
        for cell in self.table_cells:
            row_lo = min(cell.start_row_offset_idx, n_rows)
            row_hi = min(cell.end_row_offset_idx, n_rows)
            col_lo = min(cell.start_col_offset_idx, n_cols)
            col_hi = min(cell.end_col_offset_idx, n_cols)
            for r in range(row_lo, row_hi):
                for c in range(col_lo, col_hi):
                    matrix[r][c] = cell

        return matrix
- """
- OTSL
- """
# OTSL structure tokens, as interpreted by otsl_parse_texts.
OTSL_NL = "<nl>"  # ends a table row
OTSL_FCEL = "<fcel>"  # starts a cell; the cell text follows the token
OTSL_ECEL = "<ecel>"  # starts an empty cell
OTSL_LCEL = "<lcel>"  # extends the cell to the left (column span)
OTSL_UCEL = "<ucel>"  # extends the cell above (row span)
OTSL_XCEL = "<xcel>"  # extends a cell in both directions (2D span)

# Matches any single structure token.  The capturing group makes
# ``findall`` return the tokens and ``split`` keep them in its output.
_OTSL_TOKEN_PATTERN = re.compile(
    "("
    + "|".join(
        re.escape(token)
        for token in (
            OTSL_NL,
            OTSL_FCEL,
            OTSL_ECEL,
            OTSL_LCEL,
            OTSL_UCEL,
            OTSL_XCEL,
        )
    )
    + ")"
)


def otsl_extract_tokens_and_text(s: str):
    """Split an OTSL string into structure tokens and token/text parts.

    Only the six OTSL structure tokens are recognised; any other
    angle-bracket markup is treated as ordinary text.

    Args:
        s: The OTSL-encoded table string.

    Returns:
        A ``(tokens, text_parts)`` tuple.  ``tokens`` holds the structure
        tokens in order of appearance.  ``text_parts`` holds the tokens AND
        the text found between them, in document order, with empty and
        whitespace-only pieces dropped; kept text is not stripped.
    """
    tokens = _OTSL_TOKEN_PATTERN.findall(s)
    text_parts = [part for part in _OTSL_TOKEN_PATTERN.split(s) if part.strip()]
    return tokens, text_parts
def _otsl_run_length(cells, which_tokens):
    """Count how many leading items of *cells* are in *which_tokens*."""
    return sum(
        1 for _ in itertools.takewhile(lambda tok: tok in which_tokens, cells)
    )


def otsl_parse_texts(texts, tokens):
    """Build table cells from OTSL structure tokens and their texts.

    Args:
        texts: Structure tokens interleaved with the text fragments that
            follow them, in document order (the second value returned by
            ``otsl_extract_tokens_and_text``).
        tokens: The structure tokens only, in document order.

    Returns:
        A ``(table_cells, split_row_tokens)`` tuple.  ``table_cells`` has one
        ``TableCell`` per ``<fcel>``/``<ecel>`` token, with spans derived
        from the ``<lcel>``/``<ucel>``/``<xcel>`` tokens to its right and
        below.  ``split_row_tokens`` is the token matrix, one list per row,
        with short rows padded with ``<ecel>`` to the widest row.
    """
    structural_tokens = (
        OTSL_NL,
        OTSL_FCEL,
        OTSL_ECEL,
        OTSL_LCEL,
        OTSL_UCEL,
        OTSL_XCEL,
    )
    cell_tokens = (OTSL_FCEL, OTSL_ECEL, OTSL_LCEL, OTSL_UCEL, OTSL_XCEL)
    horizontal_span_tokens = (OTSL_LCEL, OTSL_XCEL)
    vertical_span_tokens = (OTSL_UCEL, OTSL_XCEL)

    # One token list per table row; the <nl> separators themselves (and any
    # empty rows they would delimit) are dropped.
    split_row_tokens = [
        list(group)
        for is_newline, group in itertools.groupby(
            tokens, lambda tok: tok == OTSL_NL
        )
        if not is_newline
    ]
    table_cells = []
    r_idx = 0
    c_idx = 0

    # Check and complete the matrix.
    if split_row_tokens:
        max_cols = max(len(row) for row in split_row_tokens)

        # Pad short rows with <ecel> so every row has the same width.
        for row in split_row_tokens:
            row.extend([OTSL_ECEL] * (max_cols - len(row)))

        # Rebuild ``texts`` so that it mirrors the padded matrix: every token
        # of every row, each followed by its text fragment when the original
        # sequence has one at that position, and one <nl> per row.
        new_texts = []
        text_idx = 0
        for row in split_row_tokens:
            for token in row:
                new_texts.append(token)
                if text_idx < len(texts) and texts[text_idx] == token:
                    text_idx += 1
                    if (
                        text_idx < len(texts)
                        and texts[text_idx] not in structural_tokens
                    ):
                        new_texts.append(texts[text_idx])
                        text_idx += 1
            new_texts.append(OTSL_NL)
            if text_idx < len(texts) and texts[text_idx] == OTSL_NL:
                text_idx += 1
        texts = new_texts

    for i, text in enumerate(texts):
        if text in (OTSL_FCEL, OTSL_ECEL):
            cell_text = ""
            row_span = 1
            col_span = 1
            right_offset = 1
            # Take the next element as this cell's text only when it really
            # is a text fragment.  A <fcel> with nothing after it must not
            # swallow the following structure token: that would corrupt the
            # cell text and hide a following <lcel>/<xcel> from span checks.
            if (
                text == OTSL_FCEL
                and i + 1 < len(texts)
                and texts[i + 1] not in structural_tokens
            ):
                cell_text = texts[i + 1]
                right_offset = 2

            # Look at the neighbours to the right and below for
            # <lcel> / <ucel> / <xcel> to determine col_span and row_span.
            next_right_cell = ""
            if i + right_offset < len(texts):
                next_right_cell = texts[i + right_offset]

            next_bottom_cell = ""
            if r_idx + 1 < len(split_row_tokens) and c_idx < len(
                split_row_tokens[r_idx + 1]
            ):
                next_bottom_cell = split_row_tokens[r_idx + 1][c_idx]

            if next_right_cell in horizontal_span_tokens:
                # Horizontally spanning (or 2D spanning) cell.
                col_span += _otsl_run_length(
                    split_row_tokens[r_idx][c_idx + 1:],
                    horizontal_span_tokens,
                )
            if next_bottom_cell in vertical_span_tokens:
                # Vertically spanning (or 2D spanning) cell.
                col_below = (row[c_idx] for row in split_row_tokens[r_idx + 1:])
                row_span += _otsl_run_length(col_below, vertical_span_tokens)

            table_cells.append(
                TableCell(
                    text=cell_text.strip(),
                    row_span=row_span,
                    col_span=col_span,
                    start_row_offset_idx=r_idx,
                    end_row_offset_idx=r_idx + row_span,
                    start_col_offset_idx=c_idx,
                    end_col_offset_idx=c_idx + col_span,
                )
            )

        # Every cell-type token occupies one column; <nl> starts a new row.
        if text in cell_tokens:
            c_idx += 1
        if text == OTSL_NL:
            r_idx += 1
            c_idx = 0

    return table_cells, split_row_tokens
def export_to_html(table_data: "TableData") -> str:
    """Render table data as an HTML ``<table>`` string.

    Args:
        table_data: The table to render; its ``num_rows``, ``num_cols``,
            ``table_cells`` and ``grid`` are read.

    Returns:
        An empty string when the table has no cells.  Otherwise
        ``<table>...</table>`` with one ``<tr>`` per row.  Column-header
        cells become ``<th>``, all others ``<td>``; ``rowspan``/``colspan``
        attributes are written only for spans greater than one, and cell
        text is stripped and HTML-escaped.
    """
    if not table_data.table_cells:
        return ""

    # ``grid`` is a computed property, so fetch it once up front.
    grid = table_data.grid
    rows_html = []
    for i in range(table_data.num_rows):
        cells_html = []
        for j in range(table_data.num_cols):
            cell = grid[i][j]
            # A spanning cell sits at every grid position it covers; emit it
            # only at its top-left position.
            if cell.start_row_offset_idx != i or cell.start_col_offset_idx != j:
                continue

            celltag = "th" if cell.column_header else "td"
            attrs = ""
            if cell.row_span > 1:
                attrs += f' rowspan="{cell.row_span}"'
            if cell.col_span > 1:
                attrs += f' colspan="{cell.col_span}"'
            content = html.escape(cell.text.strip())
            cells_html.append(f"<{celltag}{attrs}>{content}</{celltag}>")
        rows_html.append("<tr>" + "".join(cells_html) + "</tr>")

    return "<table>" + "".join(rows_html) + "</table>"
def convert_otsl_to_html(otsl_content: str):
    """Convert an OTSL-encoded table string into an HTML table string.

    The content is tokenised, parsed into table cells, wrapped in a
    ``TableData`` sized from the parsed token rows, and rendered with
    ``export_to_html``.
    """
    structure_tokens, tokens_with_text = otsl_extract_tokens_and_text(otsl_content)
    cells, token_rows = otsl_parse_texts(tokens_with_text, structure_tokens)

    # The table is as wide as its widest token row (0 when there are no rows).
    widest_row = 0
    if token_rows:
        widest_row = max(len(row) for row in token_rows)

    table = TableData(
        num_rows=len(token_rows),
        num_cols=widest_row,
        table_cells=cells,
    )
    return export_to_html(table)
|