@@ -92,21 +92,36 @@ def assign_priors(gt_boxes, gt_labels, corner_form_priors,
92
92
boxes (num_priors, 4): real values for priors.
93
93
labels (num_priros): labels for priors.
94
94
"""
95
- # size: num_priors x num_targets
95
+ # size: [num_priors, num_targets]
96
+ # 每行表示单个先验框与各个标注框的IoU
97
+ # 每列表示单个标注框与各个先验框的IoU
96
98
ious = iou_of (gt_boxes .unsqueeze (0 ), corner_form_priors .unsqueeze (1 ))
97
- # size: num_priors
99
+ # size: [num_priors]
100
+ # best_target_per_prior:每个先验框计算得到的最高IoU
101
+ # best_target_per_prior_index:每个先验框对应最高IoU的标注框下标
98
102
best_target_per_prior , best_target_per_prior_index = ious .max (1 )
99
- # size: num_targets
103
+ # size: [num_targets]
104
+ # best_prior_per_target:每个标注框计算得到的最高IoU
105
+ # best_prior_per_target_index:每个标注框对应最高IoU的先验框下标
100
106
best_prior_per_target , best_prior_per_target_index = ious .max (0 )
101
107
108
+ # 再一次确保标注框与最高IoU的先验框匹配
102
109
for target_index , prior_index in enumerate (best_prior_per_target_index ):
103
110
best_target_per_prior_index [prior_index ] = target_index
111
+
112
+ # size: [num_priors]
113
+ # 得到每个先验框对应标注框的标签/类别
114
+ labels = gt_labels [best_target_per_prior_index ]
115
+ # size: [num_priors, 4]
116
+ # 得到每个先验框对应标注框的坐标
117
+ boxes = gt_boxes [best_target_per_prior_index ]
118
+
104
119
# 2.0 is used to make sure every target has a prior assigned
120
+ # 确保每个标注框对应IoU最高的先验框的阈值大于iou_threshold
105
121
best_target_per_prior .index_fill_ (0 , best_prior_per_target_index , 2 )
106
- # size: num_priors
107
- labels = gt_labels [best_target_per_prior_index ]
122
+ # IoU小于iou_threshold的先验框设置为背景类别
108
123
labels [best_target_per_prior < iou_threshold ] = 0 # the backgournd id
109
- boxes = gt_boxes [ best_target_per_prior_index ]
124
+
110
125
return boxes , labels
111
126
112
127
0 commit comments