Walking in the Latent Space

On the Batched NMS Implementation in Torchvision

Recently I was going through torchvision’s batched NMS implementation, which lives in torchvision/ops/boxes.py in a function named batched_nms. These lines caught my attention:

if boxes.numel() > (4000 if boxes.device.type == "cpu" else 20000) and not torchvision._is_tracing():
    return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
else:
    return _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)

When you perform batched NMS, torchvision actually chooses between two implementations based on the size of the input. Note that boxes.numel() counts tensor elements, and each box has four coordinates, so the CPU cutoff of 4000 elements corresponds to 1000 boxes. For larger inputs, it calls _batched_nms_vanilla, which is defined in the same file. This function works as you would expect: it separates boxes by class id and then performs NMS on each group, as in the sketch below. _batched_nms_coordinate_trick is more interesting, as it uses a smart trick to do multi-class NMS in a single NMS pass.
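Here is a minimal sketch of the per-class approach (per_class_nms is a hypothetical name; the real _batched_nms_vanilla additionally builds a boolean keep mask and handles empty inputs):

import torch
from torchvision.ops import nms

def per_class_nms(boxes, scores, idxs, iou_threshold):
    # Run NMS independently for each class id and merge the kept indices.
    # Empty-input handling is omitted for brevity.
    keep = []
    for class_id in torch.unique(idxs):
        class_indices = torch.where(idxs == class_id)[0]
        kept = nms(boxes[class_indices], scores[class_indices], iou_threshold)
        keep.append(class_indices[kept])
    keep = torch.cat(keep)
    # Like batched_nms, return indices sorted by descending score
    return keep[scores[keep].sort(descending=True)[1]]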

Let’s suppose you run an object detector and get several bounding boxes for each class (say, three per class). Now you need to perform NMS to get rid of duplicate overlapping boxes. Normally, you would do NMS for each class separately: if two overlapping boxes belong to different classes, neither should be eliminated, as they designate different objects.

What _batched_nms_coordinate_trick does is add an offset to every bounding box that depends only on its class id, which guarantees zero overlap between boxes of different classes.

No coordinates ever need to be shifted back afterwards: nms returns the indices of the kept boxes, and those indices simply select from the original, unshifted boxes, so we get the same result as per-class NMS.
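Here is a minimal sketch of the trick (coordinate_trick_nms is a hypothetical name; the real _batched_nms_coordinate_trick in boxes.py is essentially the same, minus empty-input handling):

import torch
from torchvision.ops import nms

def coordinate_trick_nms(boxes, scores, idxs, iou_threshold):
    # Shift every box by class_id * (max_coordinate + 1) so that boxes
    # of different classes end up in disjoint coordinate ranges and can
    # never overlap, then run a single NMS pass over everything.
    max_coordinate = boxes.max()
    offsets = idxs.to(boxes) * (max_coordinate + 1)
    boxes_for_nms = boxes + offsets[:, None]
    # nms returns indices of the kept boxes; the indices are valid for
    # the original boxes too, so the shift never has to be undone.
    return nms(boxes_for_nms, scores, iou_threshold)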

There is actually a whole thread discussing this trade-off. I decided to run my own benchmark for 80 classes, averaging over 3 trials.

In my runs, _batched_nms_coordinate_trick is actually faster for smaller inputs, though on my PC the crossover seems to be around 2000 boxes rather than 1000.
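A minimal version of such a benchmark might look like this (it imports the two private functions directly from torchvision.ops.boxes, which are internal and may change between versions):

import time
import torch
from torchvision.ops.boxes import _batched_nms_coordinate_trick, _batched_nms_vanilla

def time_fn(fn, boxes, scores, idxs, iou_threshold=0.5, trials=3):
    # Average wall-clock time over a few trials; a rough measurement,
    # not a rigorous benchmark (no warmup, CPU only).
    start = time.perf_counter()
    for _ in range(trials):
        fn(boxes, scores, idxs, iou_threshold)
    return (time.perf_counter() - start) / trials

num_boxes, num_classes = 2000, 80
# Random well-formed boxes in (x1, y1, x2, y2) format
xy1 = torch.rand(num_boxes, 2) * 500
boxes = torch.cat([xy1, xy1 + torch.rand(num_boxes, 2) * 100], dim=1)
scores = torch.rand(num_boxes)
idxs = torch.randint(0, num_classes, (num_boxes,))

for fn in (_batched_nms_coordinate_trick, _batched_nms_vanilla):
    print(f"{fn.__name__}: {time_fn(fn, boxes, scores, idxs):.4f} s")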

This is an interesting way to do batched NMS; however, the fact that it is ever the faster option seems counter-intuitive. Torchvision implements NMS in \( O(n^2) \) time, and for an \( O(n^2) \) algorithm, growing the input by a factor of \( t \) blows the runtime up from \( n^2 \) to \( (tn)^2 = t^2 n^2 \). On the other hand, running the algorithm \( t \) times on inputs of size \( n \) costs only \( t \cdot n^2 \), which means complexity linear in the number of classes.
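To put numbers on it: with \( t = 80 \) classes of \( n \) boxes each, a single pass over all \( 80n \) boxes costs on the order of \( (80n)^2 = 6400 n^2 \) comparisons, while 80 separate per-class passes cost only \( 80 n^2 \), so asymptotically the vanilla path should win by a factor of 80.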

My best guess here is that _batched_nms_vanilla carries extra overhead from tensor masking and conditionals, which are relatively expensive operations and dominate at small input sizes. Confirming this, though, would require more in-depth profiling of PyTorch’s code, which I might do in the future.