# struct_eqtable.py
  1. import torch
  2. from struct_eqtable import build_model
  3. from magic_pdf.model.sub_modules.table.table_utils import minify_html
  4. class StructTableModel:
  5. def __init__(self, model_path, max_new_tokens=1024, max_time=60):
  6. # init
  7. assert torch.cuda.is_available(), "CUDA must be available for StructEqTable model."
  8. self.model = build_model(
  9. model_ckpt=model_path,
  10. max_new_tokens=max_new_tokens,
  11. max_time=max_time,
  12. lmdeploy=False,
  13. flash_attn=False,
  14. batch_size=1,
  15. ).cuda()
  16. self.default_format = "html"
  17. def predict(self, images, output_format=None, **kwargs):
  18. if output_format is None:
  19. output_format = self.default_format
  20. else:
  21. if output_format not in ['latex', 'markdown', 'html']:
  22. raise ValueError(f"Output format {output_format} is not supported.")
  23. results = self.model(
  24. images, output_format=output_format
  25. )
  26. if output_format == "html":
  27. results = [minify_html(html) for html in results]
  28. return results