|
|
@@ -12,7 +12,7 @@
|
|
|
# See the License for the specific language governing permissions and
|
|
|
# limitations under the License.
|
|
|
import math
|
|
|
-
|
|
|
+import bisect
|
|
|
import numpy as np
|
|
|
|
|
|
from ..components import convert_points_to_boxes
|
|
|
@@ -203,30 +203,35 @@ def get_html_result(
|
|
|
str: Generated HTML content as a string.
|
|
|
"""
|
|
|
pred_html = []
|
|
|
- td_index = 0
|
|
|
- td_count = 0
|
|
|
- matched_list_index = 0
|
|
|
+ # 全局单元格序号,从 0 开始,和 table_cells_flag 使用同一坐标系
|
|
|
+ td_global = 0
|
|
|
+
|
|
|
head_structure = pred_structures[0:3]
|
|
|
html = "".join(head_structure)
|
|
|
table_structure = pred_structures[3:-3]
|
|
|
for tag in table_structure:
|
|
|
- matched_index = all_matched_index[matched_list_index]
|
|
|
if "</td>" in tag:
|
|
|
+ # 通过全局 td 序号定位当前“行索引”和“列索引”
|
|
|
+ # table_cells_flag 是每行起始单元格的前缀计数(已 append 了末尾)
|
|
|
+ row_idx = max(0, bisect.bisect_right(table_cells_flag, td_global) - 1)
|
|
|
+ col_idx = td_global - table_cells_flag[row_idx]
|
|
|
+ matched_index = all_matched_index[row_idx] if row_idx < len(all_matched_index) else {}
|
|
|
+
|
|
|
if "<td></td>" == tag:
|
|
|
- pred_html.extend("<td>")
|
|
|
- if td_index in matched_index.keys():
|
|
|
- if len(matched_index[td_index]) == 0:
|
|
|
+ pred_html.append("<td>")
|
|
|
+ if col_idx in matched_index.keys():
|
|
|
+ if len(matched_index[col_idx]) == 0:
|
|
|
continue
|
|
|
b_with = False
|
|
|
if (
|
|
|
- "<b>" in ocr_contents[matched_index[td_index][0]]
|
|
|
- and len(matched_index[td_index]) > 1
|
|
|
+ "<b>" in ocr_contents[matched_index[col_idx][0]]
|
|
|
+ and len(matched_index[col_idx]) > 1
|
|
|
):
|
|
|
b_with = True
|
|
|
- pred_html.extend("<b>")
|
|
|
- for i, td_index_index in enumerate(matched_index[td_index]):
|
|
|
+ pred_html.append("<b>")
|
|
|
+ for i, td_index_index in enumerate(matched_index[col_idx]):
|
|
|
content = ocr_contents[td_index_index]
|
|
|
- if len(matched_index[td_index]) > 1:
|
|
|
+ if len(matched_index[col_idx]) > 1:
|
|
|
if len(content) == 0:
|
|
|
continue
|
|
|
if content[0] == " ":
|
|
|
@@ -237,23 +242,17 @@ def get_html_result(
|
|
|
content = content[:-4]
|
|
|
if len(content) == 0:
|
|
|
continue
|
|
|
- if i != len(matched_index[td_index]) - 1 and " " != content[-1]:
|
|
|
+ if i != len(matched_index[col_idx]) - 1 and " " != content[-1]:
|
|
|
content += " "
|
|
|
- pred_html.extend(content)
|
|
|
+ pred_html.append(content)
|
|
|
if b_with:
|
|
|
- pred_html.extend("</b>")
|
|
|
+ pred_html.append("</b>")
|
|
|
if "<td></td>" == tag:
|
|
|
pred_html.append("</td>")
|
|
|
else:
|
|
|
pred_html.append(tag)
|
|
|
- td_index += 1
|
|
|
- td_count += 1
|
|
|
- if (
|
|
|
- td_count >= table_cells_flag[matched_list_index + 1]
|
|
|
- and matched_list_index < len(all_matched_index) - 1
|
|
|
- ):
|
|
|
- matched_list_index += 1
|
|
|
- td_index = 0
|
|
|
+ # 推进到下一个全局单元格
|
|
|
+ td_global += 1
|
|
|
else:
|
|
|
pred_html.append(tag)
|
|
|
html += "".join(pred_html)
|
|
|
@@ -408,6 +407,23 @@ def map_and_get_max(table_cells_flag, row_start_index):
|
|
|
return max_values
|
|
|
|
|
|
|
|
|
+def build_structure_from_cells(table_cells_flag: list) -> list:
|
|
|
+ """
|
|
|
+ 用单元格检测的行起始标记(前缀和)构造一个简单的表格结构tokens:
|
|
|
+ head(3项) + [<tr>, <td></td>*n, </tr>]*R + end(3项)
|
|
|
+ """
|
|
|
+ head = ["<html>", "<body>", "<table>"]
|
|
|
+ body = []
|
|
|
+ for r in range(len(table_cells_flag) - 1):
|
|
|
+ body.append("<tr>")
|
|
|
+ cols = table_cells_flag[r + 1] - table_cells_flag[r]
|
|
|
+ for _ in range(cols):
|
|
|
+ body.append("<td></td>")
|
|
|
+ body.append("</tr>")
|
|
|
+ end = ["</table>", "</body>", "</html>"]
|
|
|
+ return head + body + end
|
|
|
+
|
|
|
+
|
|
|
def get_table_recognition_res(
|
|
|
table_box: list,
|
|
|
table_structure_result: list,
|
|
|
@@ -475,9 +491,24 @@ def get_table_recognition_res(
|
|
|
matched_index = match_table_and_ocr(
|
|
|
table_cells_result, ocr_dt_boxes, table_cells_flag, table_cells_flag
|
|
|
)
|
|
|
- pred_html = get_html_result(
|
|
|
- matched_index, ocr_texts_res, table_structure_result, row_start_index
|
|
|
- )
|
|
|
+ # 对齐检测列数与结构列数,若不一致则回退到基于cells的结构骨架
|
|
|
+ use_cells_skeleton = False
|
|
|
+ for i in range(len(table_cells_flag) - 1):
|
|
|
+ cols_cells = table_cells_flag[i + 1] - table_cells_flag[i]
|
|
|
+ cols_struct = row_start_index[i + 1] - row_start_index[i]
|
|
|
+ if cols_cells != cols_struct:
|
|
|
+ use_cells_skeleton = True
|
|
|
+ break
|
|
|
+
|
|
|
+ if use_cells_skeleton:
|
|
|
+ skeleton = build_structure_from_cells(table_cells_flag)
|
|
|
+ pred_html = get_html_result(
|
|
|
+ matched_index, ocr_texts_res, skeleton, table_cells_flag
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ pred_html = get_html_result(
|
|
|
+ matched_index, ocr_texts_res, table_structure_result, row_start_index
|
|
|
+ )
|
|
|
|
|
|
single_img_res = {
|
|
|
"cell_box_list": table_cells_result,
|