# ppTableModel.py (2.6 KB)
  1. import cv2
  2. from paddleocr.ppstructure.table.predict_table import TableSystem
  3. from paddleocr.ppstructure.utility import init_args
  4. from magic_pdf.libs.Constants import *
  5. import os
  6. from PIL import Image
  7. import numpy as np
  8. class ppTableModel(object):
  9. """
  10. This class is responsible for converting image of table into HTML format using a pre-trained model.
  11. Attributes:
  12. - table_sys: An instance of TableSystem initialized with parsed arguments.
  13. Methods:
  14. - __init__(config): Initializes the model with configuration parameters.
  15. - img2html(image): Converts a PIL Image or NumPy array to HTML string.
  16. - parse_args(**kwargs): Parses configuration arguments.
  17. """
  18. def __init__(self, config):
  19. """
  20. Parameters:
  21. - config (dict): Configuration dictionary containing model_dir and device.
  22. """
  23. args = self.parse_args(**config)
  24. self.table_sys = TableSystem(args)
  25. def img2html(self, image):
  26. """
  27. Parameters:
  28. - image (PIL.Image or np.ndarray): The image of the table to be converted.
  29. Return:
  30. - HTML (str): A string representing the HTML structure with content of the table.
  31. """
  32. if isinstance(image, Image.Image):
  33. image = np.asarray(image)
  34. image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
  35. pred_res, _ = self.table_sys(image)
  36. pred_html = pred_res["html"]
  37. # res = '<td><table border="1">' + pred_html.replace("<html><body><table>", "").replace(
  38. # "</table></body></html>","") + "</table></td>\n"
  39. return pred_html
  40. def parse_args(self, **kwargs):
  41. parser = init_args()
  42. model_dir = kwargs.get("model_dir")
  43. table_model_dir = os.path.join(model_dir, TABLE_MASTER_DIR)
  44. table_char_dict_path = os.path.join(model_dir, TABLE_MASTER_DICT)
  45. det_model_dir = os.path.join(model_dir, DETECT_MODEL_DIR)
  46. rec_model_dir = os.path.join(model_dir, REC_MODEL_DIR)
  47. rec_char_dict_path = os.path.join(model_dir, REC_CHAR_DICT)
  48. device = kwargs.get("device", "cpu")
  49. use_gpu = True if device.startswith("cuda") else False
  50. config = {
  51. "use_gpu": use_gpu,
  52. "table_max_len": kwargs.get("table_max_len", TABLE_MAX_LEN),
  53. "table_algorithm": "TableMaster",
  54. "table_model_dir": table_model_dir,
  55. "table_char_dict_path": table_char_dict_path,
  56. "det_model_dir": det_model_dir,
  57. "rec_model_dir": rec_model_dir,
  58. "rec_char_dict_path": rec_char_dict_path,
  59. }
  60. parser.set_defaults(**config)
  61. return parser.parse_args([])