ppTableModel.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. from paddleocr.ppstructure.table.predict_table import TableSystem
  2. from paddleocr.ppstructure.utility import init_args
  3. from magic_pdf.libs.Constants import *
  4. import os
  5. from PIL import Image
  6. import numpy as np
  class ppTableModel(object):
      """
      Convert an image of a table into an HTML fragment using a pre-trained model.

      Attributes:
      - table_sys: An instance of TableSystem initialized with the arguments
        produced by parse_args().

      Methods:
      - __init__(config): Builds the TableSystem from configuration parameters.
      - img2html(image): Converts a PIL Image or NumPy array to an HTML string.
      - parse_args(**kwargs): Turns configuration values into TableSystem arguments.
      """

      def __init__(self, config):
          """
          Parameters:
          - config (dict): Configuration dictionary unpacked into parse_args().
            It must contain "model_dir"; "device" and "table_max_len" are optional.
          """
          args = self.parse_args(**config)
          self.table_sys = TableSystem(args)

      def img2html(self, image):
          """
          Run table recognition on one image and wrap the result in a table cell.

          Parameters:
          - image (PIL.Image or np.ndarray): The image of the table to be converted.
            A PIL image is converted to a NumPy array first; any other value is
            handed to the TableSystem unchanged.

          Return:
          - HTML (str): '<td><table border="1">...</table></td>' followed by a
            newline, where "..." is the predicted table markup.
          """
          if isinstance(image, Image.Image):
              image = np.array(image)
          pred_res, _ = self.table_sys(image)
          pred_html = pred_res["html"]
          # Drop the document-level wrapper around the prediction so that only the
          # table's inner markup is embedded in the fragment built below.
          inner_html = pred_html.replace("<html><body><table>", "")
          inner_html = inner_html.replace("</table></body></html>", "")
          return '<td><table border="1">' + inner_html + "</table></td>\n"

      @staticmethod
      def _is_gpu_device(device):
          """
          Tell whether a device specifier names a CUDA device.

          Parameters:
          - device: Device specifier, e.g. "cpu", "cuda" or "cuda:0". The value is
            converted with str() before matching, so a non-string device object
            whose string form follows the same pattern is accepted as well.

          Return:
          - bool: True for "cuda" and for "cuda:<index>", False for anything else.
          """
          name = str(device)
          # An exact comparison with "cuda" would miss indexed devices such as
          # "cuda:0" and silently fall back to CPU inference.
          return name == "cuda" or name.startswith("cuda:")

      def parse_args(self, **kwargs):
          """
          Build the argument namespace used to construct the TableSystem.

          Parameters:
          - model_dir (str): Root directory of the model files. Every model and
            dictionary path is resolved relative to it. If it is missing,
            os.path.join() raises TypeError because the value is None.
          - device (str, optional): Target device, "cpu" by default. GPU usage is
            enabled for "cuda" and "cuda:<index>".
          - table_max_len (optional): Overrides the TABLE_MAX_LEN default.

          Return:
          - The result of parse_args([]) on the parser returned by init_args(),
            after the values computed here have been installed as its defaults.
          """
          parser = init_args()
          model_dir = kwargs.get("model_dir")
          device = kwargs.get("device", "cpu")
          config = {
              "use_gpu": self._is_gpu_device(device),
              "table_max_len": kwargs.get("table_max_len", TABLE_MAX_LEN),
              "table_algorithm": TABLE_MASTER,
              "table_model_dir": os.path.join(model_dir, TABLE_MASTER_DIR),
              "table_char_dict_path": os.path.join(model_dir, TABLE_MASTER_DICT),
              "det_model_dir": os.path.join(model_dir, DETECT_MODEL_DIR),
              "rec_model_dir": os.path.join(model_dir, REC_MODEL_DIR),
              "rec_char_dict_path": os.path.join(model_dir, REC_CHAR_DICT),
          }
          parser.set_defaults(**config)
          # An explicit empty argument list is parsed, so the values above (plus
          # the parser's own defaults) are what end up in the returned namespace.
          return parser.parse_args([])