# StructTableModel.py (960 B)

  1. import torch
  2. from struct_eqtable import build_model
  3. class StructTableModel:
  4. def __init__(self, model_path, max_new_tokens=1024, max_time=60):
  5. # init
  6. assert torch.cuda.is_available(), "CUDA must be available for StructEqTable model."
  7. self.model = build_model(
  8. model_ckpt=model_path,
  9. max_new_tokens=max_new_tokens,
  10. max_time=max_time,
  11. lmdeploy=False,
  12. flash_attn=False,
  13. batch_size=1,
  14. ).cuda()
  15. self.default_format = "html"
  16. def predict(self, images, output_format=None, **kwargs):
  17. if output_format is None:
  18. output_format = self.default_format
  19. else:
  20. if output_format not in ['latex', 'markdown', 'html']:
  21. raise ValueError(f"Output format {output_format} is not supported.")
  22. results = self.model(
  23. images, output_format=output_format
  24. )
  25. return results