
Modify LaTeX OCR data processing to support CI (#1927)

* Modify the data processing in LaTeX OCR and change the config YAML

* Merge the train and val folders of the dataset into a single image folder
liuhongen1234567 1 year ago
parent commit
f592c8c8bd

+ 2 - 4
docs/tutorials/data/dataset_check.md

@@ -491,11 +491,9 @@ tar -xf ./dataset/ocr_rec_latexocr_dataset_example.tar -C ./dataset/
 To validate the dataset, a single command is all that is needed:
 
 ```bash
-python main.py -c paddlex/configs/text_recognition/LaTeX_OCR_rec.yml \
+python main.py -c paddlex/configs/formula_recognition/LaTeX_OCR_rec.yaml \
     -o Global.mode=check_dataset \
-    -o Global.dataset_dir=./dataset/ocr_rec_latexocr_dataset_example \
-    -o CheckDataset.convert.enable=True \
-    -o CheckDataset.convert.src_dataset_type=PKL
+    -o Global.dataset_dir=./dataset/ocr_rec_latexocr_dataset_example
 ```
 
 After the above command is executed, PaddleX validates the dataset and collects its basic statistics. When the command succeeds, `Check dataset passed !` is printed in the log, and the related outputs are saved under `./output/check_dataset` in the current directory, including visualized sample images and a histogram of the sample distribution. The validation results are saved to `./output/check_dataset_result.json`, with the following contents
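
Since the diff hunk ends here, the JSON contents themselves are not shown. As a minimal sketch (not part of this commit), the result file can be inspected as below; the exact keys depend on the PaddleX version.

```python
import json

# Illustrative only: load the result file written by the check command above.
with open("./output/check_dataset_result.json", encoding="utf-8") as f:
    result = json.load(f)

# Print whatever top-level fields the dataset check reported.
for key, value in result.items():
    print(f"{key}: {value}")
```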

+ 3 - 3
paddlex/configs/text_recognition/LaTeX_OCR_rec.yml → paddlex/configs/formula_recognition/LaTeX_OCR_rec.yaml

@@ -7,15 +7,15 @@ Global:
 
 CheckDataset:
   convert: 
-    enable: False
-    src_dataset_type: null
+    enable: True
+    src_dataset_type: PKL
   split: 
     enable: False
     train_percent: null
     val_percent: null
 
 Train:
-  epochs_iters: 2
+  epochs_iters: 20
   batch_size_train: 40
   batch_size_val: 10
   learning_rate: 0.0001
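
The effect of this config change is that PKL conversion is now the default for this model, which is why the docs hunk above drops the `CheckDataset.convert.*` CLI overrides. A minimal sketch verifying the new defaults, assuming the file is plain YAML readable with PyYAML (PaddleX's own loader may differ):

```python
import yaml  # assumption: PyYAML is available; PaddleX may use its own loader

with open("paddlex/configs/formula_recognition/LaTeX_OCR_rec.yaml") as f:
    cfg = yaml.safe_load(f)

# After this commit, PKL conversion is enabled by default, so the CLI no
# longer needs the -o CheckDataset.convert.enable=True overrides.
assert cfg["CheckDataset"]["convert"]["enable"] is True
assert cfg["CheckDataset"]["convert"]["src_dataset_type"] == "PKL"
```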

+ 1 - 4
paddlex/modules/text_recognition/dataset_checker/dataset_src/check_dataset.py

@@ -80,10 +80,7 @@ def check(dataset_dir,
                                 "in {file_list} should be {valid_num_parts} (current delimiter is '{delim}')."
                             )
                         file_name = substr[0]
-                        if dataset_type == 'LaTeXOCRDataset':
-                           img_path = osp.join(dataset_dir, tag, file_name)                        
-                        else:
-                            img_path = osp.join(dataset_dir, file_name)
+                        img_path = osp.join(dataset_dir, file_name)
                         if len(sample_paths[tag]) < max_recorded_sample_cnts:
                             sample_paths[tag].append(
                                 os.path.relpath(img_path, output))
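
With the train and val folders merged into a single image folder (see the second commit message bullet), the dataset checker no longer needs a `LaTeXOCRDataset`-specific branch: every dataset type now resolves image paths directly under the dataset root. A minimal sketch of the unified resolution; the file name below is hypothetical:

```python
import os.path as osp

dataset_dir = "./dataset/ocr_rec_latexocr_dataset_example"
file_name = "images/0001.png"  # hypothetical entry from train.txt

# One uniform join; no per-dataset-type branch on 'LaTeXOCRDataset' anymore.
img_path = osp.join(dataset_dir, file_name)
print(img_path)  # ./dataset/ocr_rec_latexocr_dataset_example/images/0001.png
```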

+ 5 - 4
paddlex/modules/text_recognition/dataset_checker/dataset_src/convert_dataset.py

@@ -55,13 +55,14 @@ def convert(dataset_type, input_dir):
 
 def convert_pkl_dataset(root_dir):
     for anno in ['train.txt','val.txt']:
-        src_img_dir = os.path.join(root_dir, anno.replace(".txt",""))
+        src_img_dir = root_dir
         src_anno_path = os.path.join(root_dir, anno)
         txt2pickle(src_img_dir, src_anno_path, root_dir)
 
 def txt2pickle(images, equations, save_dir):
     imagesize = try_import("imagesize")
-    save_p = os.path.join(save_dir, "latexocr_{}.pkl".format(images.split("/")[-1]))
+    phase = os.path.basename(equations).replace(".txt","")
+    save_p = os.path.join(save_dir, "latexocr_{}.pkl".format(phase))
     min_dimensions = (32, 32)
     max_dimensions = (672, 192)
     max_length = 512
@@ -73,7 +74,7 @@ def txt2pickle(images, equations, save_dir):
             for l in tqdm(lines, total = len(lines)):
                 l = l.strip()
                 img_name, equation = l.split("\t")
-                img_path = os.path.join( os.path.abspath(images), img_name)
+                img_path = os.path.join(images,img_name)
                 width, height = imagesize.get(img_path)
                 if (
                     min_dimensions[0] <= width <= max_dimensions[0]
@@ -81,7 +82,7 @@ def txt2pickle(images, equations, save_dir):
                 ):
                     divide_h = math.ceil(height / 16) * 16
                     divide_w = math.ceil(width / 16) * 16
-                    data[(divide_w, divide_h)].append((equation, img_path))
+                    data[(divide_w, divide_h)].append((equation, img_name))
                     pic_num +=1
         data = dict(data)
         with open(save_p, "wb") as file:
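
For context, `txt2pickle` buckets samples by their image size rounded up to the nearest multiple of 16, and after this commit it stores the relative image name instead of an absolute path. A minimal, self-contained sketch of that bucketing; the sample values are illustrative:

```python
import math
from collections import defaultdict

data = defaultdict(list)
# (equation, image name, width, height); values are illustrative only.
samples = [("x^2 + y^2 = z^2", "images/0001.png", 300, 100)]

for equation, img_name, width, height in samples:
    divide_w = math.ceil(width / 16) * 16   # 300 -> 304
    divide_h = math.ceil(height / 16) * 16  # 100 -> 112
    data[(divide_w, divide_h)].append((equation, img_name))

print(dict(data))  # {(304, 112): [('x^2 + y^2 = z^2', 'images/0001.png')]}
```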