소스 검색

modify the base and vis

sunyanfang01 5 년 전
부모
커밋
120db901e4
2개의 변경된 파일13개의 추가작업 그리고 6개의 파일을 삭제
  1. 2 0
      paddlex/cv/models/base.py
  2. 11 6
      paddlex/cv/models/slim/visualize.py

+ 2 - 0
paddlex/cv/models/base.py

@@ -371,6 +371,8 @@ class BaseAPI:
                    use_vdl=False,
                    early_stop=False,
                    early_stop_patience=5):
+        if train_dataset.num_samples < train_batch_size:
+            raise Exception('The amount of training datset must be larger than batch size.')
         if not osp.isdir(save_dir):
             if osp.exists(save_dir):
                 os.remove(save_dir)

+ 11 - 6
paddlex/cv/models/slim/visualize.py

@@ -30,7 +30,6 @@ def visualize(model, sensitivities_file, save_dir='./'):
     import matplotlib
     matplotlib.use('Agg')
     import matplotlib.pyplot as plt
-
     program = model.test_prog
     place = model.places[0]
     fig = plt.figure()
@@ -51,15 +50,21 @@ def visualize(model, sensitivities_file, save_dir='./'):
         min(np.array(x)) - 0.01,
         max(np.array(x)) + 0.01, 0.05)
     my_y_ticks = np.arange(0.05, 1, 0.05)
-    plt.xticks(my_x_ticks, fontsize=3)
-    plt.yticks(my_y_ticks, fontsize=3)
+    plt.xticks(my_x_ticks, rotation=30, fontsize=8)
+    plt.yticks(my_y_ticks, fontsize=8)
     for a, b in zip(x, y):
         plt.text(
             a,
-            b, (float('%0.4f' % a), float('%0.3f' % b)),
+            b, (float('%0.3f' % a), float('%0.3f' % b)),
             ha='center',
             va='bottom',
-            fontsize=3)
+            fontsize=8)
+    plt.rcParams['savefig.dpi'] = 120
+    plt.rcParams['figure.dpi'] = 150
     suffix = osp.splitext(sensitivities_file)[-1]
-    plt.savefig('sensitivities.png', dpi=800)
+    plt.savefig(osp.join(save_dir, 'sensitivities.png'))
     plt.close()
+    import pickle
+    coor = dict(zip(x, y))
+    output = open(osp.join(save_dir, 'sensitivities_xy.pkl'), 'wb')
+    pickle.dump(coor, output)