|
|
@@ -314,7 +314,8 @@ class CustomPEKModel:
|
|
|
mfr_res = []
|
|
|
for mf_img in dataloader:
|
|
|
mf_img = mf_img.to(self.device)
|
|
|
- output = self.mfr_model.generate({'image': mf_img})
|
|
|
+ with torch.no_grad():
|
|
|
+ output = self.mfr_model.generate({'image': mf_img})
|
|
|
mfr_res.extend(output['pred_str'])
|
|
|
for res, latex in zip(latex_filling_list, mfr_res):
|
|
|
res['latex'] = latex_rm_whitespace(latex)
|
|
|
@@ -336,7 +337,14 @@ class CustomPEKModel:
|
|
|
elif int(res['category_id']) in [5]:
|
|
|
table_res_list.append(res)
|
|
|
|
|
|
- clean_memory()
|
|
|
+ if torch.cuda.is_available():
|
|
|
+ properties = torch.cuda.get_device_properties(self.device)
|
|
|
+ total_memory = properties.total_memory / (1024 ** 3) # 将字节转换为 GB
|
|
|
+ if total_memory <= 8:
|
|
|
+ gc_start = time.time()
|
|
|
+ clean_memory()
|
|
|
+ gc_time = round(time.time() - gc_start, 2)
|
|
|
+ logger.info(f"gc time: {gc_time}")
|
|
|
|
|
|
# ocr识别
|
|
|
if self.apply_ocr:
|