Update: Some of the TensorRT plugins have been released as open source. The old version of the NMS plugin is located at https://github.com/NVIDIA/TensorRT/tree/master/plugin/nmsPlugin and the new version can be found at https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin.
While trying to convert a TensorFlow detection network to TensorRT, I needed to either implement a new non-maximum suppression (NMS) layer or use NVIDIA's createNMSPlugin layer. After a quick look through the poor documentation ([1], [2]), I was forced to experiment with the layer by feeding it inputs of different sizes and hoping to get it working.
After a few hours of frustrating trial and error, I implemented a new NMS plugin layer using [3]. Unfortunately, this implementation did not perform very well, which I believe is because of the synchronization operations used in [3]. [4] would probably have been a better alternative, as it is lower level and leaves synchronization to the developer.
Fortunately, with the help of coworkers, we were finally able to figure out the proper inputs for the plugin. Hopefully this helps someone else.
Inputs
- Prediction locations from the network. Shape [4*number_of_boxes, 1, 1] if the shareLocation parameter in DetectionOutputParameters is set to true (one set of box offsets shared by all classes), or [4*number_of_boxes*number_of_classes, 1, 1] if it is set to false.
- Class confidence tensor. Shape [number_of_classes*number_of_boxes, 1, 1].
- Prior box locations and variances. Shape [2, number_of_boxes*4, 1].
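To check the three input tensors before building the engine, the shape rules above can be captured in a small helper. This is a minimal sketch: the function and argument names are hypothetical, share_location stands in for the shareLocation field of DetectionOutputParameters, and shapes are given as (C, H, W) without the batch dimension.

```python
def expected_input_shapes(num_boxes, num_classes, share_location=True):
    """Return the (C, H, W) shapes expected for the three NMS plugin inputs.

    Hypothetical helper: num_boxes, num_classes and share_location mirror
    number_of_boxes, number_of_classes and shareLocation from the text.
    """
    # shareLocation=True: one set of 4 offsets per box, shared by all classes.
    # shareLocation=False: every class has its own 4 offsets per box.
    loc_classes = 1 if share_location else num_classes
    return {
        "loc": (4 * num_boxes * loc_classes, 1, 1),
        "conf": (num_classes * num_boxes, 1, 1),
        # Channel 0 holds the prior boxes, channel 1 the matching variances.
        "prior": (2, 4 * num_boxes, 1),
    }
```

For example, with 1917 boxes and 91 classes and shareLocation set to true, the helper gives (7668, 1, 1) for the locations, (174447, 1, 1) for the confidences and (2, 7668, 1) for the priors.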
You can change the input order by setting the inputOrder parameter in DetectionOutputParameters.
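The effect of inputOrder can be illustrated with a short sketch. It assumes that entry 0 of inputOrder is the index of the location input, entry 1 the index of the confidence input and entry 2 the index of the prior input; the function name split_inputs is hypothetical.

```python
def split_inputs(inputs, input_order=(0, 1, 2)):
    """Pick the location, confidence and prior tensors out of the plugin inputs.

    Assumption: input_order mirrors inputOrder in DetectionOutputParameters,
    i.e. input_order[0] indexes the locations, input_order[1] the
    confidences and input_order[2] the priors.
    """
    if sorted(input_order) != [0, 1, 2]:
        raise ValueError("input_order must be a permutation of (0, 1, 2)")
    loc = inputs[input_order[0]]
    conf = inputs[input_order[1]]
    prior = inputs[input_order[2]]
    return loc, conf, prior
```

If a network feeds the plugin [locations, priors, confidences], an input order of (0, 2, 1) maps them back to the expected roles.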
Outputs
- Final prediction boxes. Shape [1, keep_topk, 7], where keep_topk is the keepTopK parameter in DetectionOutputParameters. Each row has the format [ImageId, Label, Confidence, Xmin, Ymin, Xmax, Ymax].
- Number of valid boxes. Shape [1, 1, 1]. This value is an int32.
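Because the two outputs have different element types, they have to be decoded differently once they are copied back to the host. The sketch below assumes both buffers arrive as raw bytes in native byte order and that the valid boxes are the first keep_count rows, with the remaining rows being padding; the function name and field names are hypothetical.

```python
import struct

# Names for the 7 values in each detection row, in the order given above.
FIELDS = ("image_id", "label", "confidence", "xmin", "ymin", "xmax", "ymax")


def parse_nms_outputs(detections_raw, count_raw, keep_top_k):
    """Decode the two NMS plugin output buffers.

    detections_raw: bytes of the [1, keep_topk, 7] float32 tensor.
    count_raw: bytes of the [1, 1, 1] int32 tensor.
    Assumes native byte order and that valid rows come first.
    """
    values = struct.unpack(f"={keep_top_k * 7}f", detections_raw)
    # The count is an int32; decoding it as a float gives a meaningless value.
    (keep_count,) = struct.unpack("=i", count_raw)
    num_valid = max(0, min(keep_count, keep_top_k))
    rows = [values[i * 7:(i + 1) * 7] for i in range(num_valid)]
    return [dict(zip(FIELDS, row)) for row in rows]
```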
Input format of tensors
Prior box locations and variances must have the following format:
[[box1_prior1, box1_prior2, box1_prior3, box1_prior4]
[box2_prior1, box2_prior2, box2_prior3, box2_prior4]
...
[box1_variance1, box1_variance2, box1_variance3, box1_variance4]
[box2_variance1, box2_variance2, box2_variance3, box2_variance4]
...]
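Packing per-box priors and variances into this layout amounts to concatenating all prior coordinates followed by all variances. A minimal sketch, assuming priors and variances are given as lists of 4-tuples with one tuple per box (pack_priors is a hypothetical name):

```python
def pack_priors(priors, variances):
    """Flatten priors and variances into the [2, number_of_boxes*4, 1] layout.

    priors, variances: lists of 4-tuples, one per box (assumed input format).
    Returns a flat list: all prior values first, then all variance values.
    """
    if len(priors) != len(variances):
        raise ValueError("need one variance tuple per prior box")
    for box in list(priors) + list(variances):
        if len(box) != 4:
            raise ValueError("every prior and variance needs exactly 4 values")
    flat_priors = [v for box in priors for v in box]
    flat_vars = [v for var in variances for v in var]
    # Channel 0 (all priors) followed by channel 1 (all variances).
    return flat_priors + flat_vars
```

SSD-style models commonly use the same variance, such as (0.1, 0.1, 0.2, 0.2), for every box, but the layout works the same with per-box values.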
The box encoding can be chosen using the codeType parameter in DetectionOutputParameters.
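To see how the location offsets, priors and variances fit together, here is a sketch of the center-size decoding used by the original Caffe SSD code. It assumes priors are stored as corners (xmin, ymin, xmax, ymax) and that variances are not already folded into the predictions (varianceEncodedInTarget set to false); whether a given codeType in the plugin matches exactly this formula is an assumption worth checking against your model.

```python
import math


def decode_center_size(loc, prior, variance):
    """Decode one box offset with the SSD center-size convention.

    loc: (dx, dy, dw, dh) predicted by the network.
    prior: (xmin, ymin, xmax, ymax) prior box in corner form (assumed).
    variance: (v0, v1, v2, v3) stored alongside the prior.
    Returns the decoded box as (xmin, ymin, xmax, ymax).
    """
    p_w = prior[2] - prior[0]
    p_h = prior[3] - prior[1]
    p_cx = (prior[0] + prior[2]) / 2
    p_cy = (prior[1] + prior[3]) / 2
    # Centers are shifted linearly; sizes are scaled exponentially.
    cx = variance[0] * loc[0] * p_w + p_cx
    cy = variance[1] * loc[1] * p_h + p_cy
    w = math.exp(variance[2] * loc[2]) * p_w
    h = math.exp(variance[3] * loc[3]) * p_h
    return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)
```

With all offsets equal to zero the decoded box equals the prior, which makes a quick sanity check that the prior and variance channels are laid out correctly.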