|
|
@@ -921,7 +921,7 @@ def coco_error_analysis(eval_details_file=None,
|
|
|
|
|
|
"""
|
|
|
|
|
|
- from multiprocessing import Pool
|
|
|
+ import multiprocessing as mp
|
|
|
from pycocotools.coco import COCO
|
|
|
from pycocotools.cocoeval import COCOeval
|
|
|
|
|
|
@@ -968,10 +968,11 @@ def coco_error_analysis(eval_details_file=None,
|
|
|
ps = np.vstack([ps, np.zeros((4, *ps.shape[1:]))])
|
|
|
catIds = cocoGt.getCatIds()
|
|
|
recThrs = cocoEval.params.recThrs
|
|
|
- with Pool(processes=48) as pool:
|
|
|
- args = [(k, cocoDt, cocoGt, catId, iou_type)
|
|
|
- for k, catId in enumerate(catIds)]
|
|
|
- analyze_results = pool.starmap(analyze_individual_category, args)
|
|
|
+ thread_num = mp.cpu_count() if mp.cpu_count() < 8 else 8
|
|
|
+ thread_pool = mp.pool.ThreadPool(thread_num)
|
|
|
+ args = [(k, cocoDt, cocoGt, catId, iou_type)
|
|
|
+ for k, catId in enumerate(catIds)]
|
|
|
+ analyze_results = thread_pool.starmap(analyze_individual_category, args)
|
|
|
for k, catId in enumerate(catIds):
|
|
|
nm = cocoGt.loadCats(catId)[0]
|
|
|
logging.info('--------------saving {}-{}---------------'.format(
|
|
|
@@ -996,6 +997,7 @@ def coco_error_analysis(eval_details_file=None,
|
|
|
makeplot(recThrs, ps[:, :, k], res_out_dir, nm['name'], iou_type)
|
|
|
makeplot(recThrs, ps, res_out_dir, 'allclass', iou_type)
|
|
|
|
|
|
+ np.linspace = fixed_linspace
|
|
|
coco_gt = COCO()
|
|
|
coco_gt.dataset = gt
|
|
|
coco_gt.createIndex()
|
|
|
@@ -1006,4 +1008,5 @@ def coco_error_analysis(eval_details_file=None,
|
|
|
if pred_mask is not None:
|
|
|
coco_dt = loadRes(coco_gt, pred_mask)
|
|
|
_analyze_results(coco_gt, coco_dt, res_type='segm', out_dir=save_dir)
|
|
|
+ np.linspace = backup_linspace
|
|
|
logging.info("The analysis figures are saved in {}".format(save_dir))
|