// Non-maximum suppression (NMS) for 3D obstacle bounding boxes (center XYZ, dimensions l/w/h, heading angle theta).
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

#include <opencv2/opencv.hpp>
/// Plain value type describing one scored 3D box candidate.
struct BoundingBox3D
{
    /// Box center coordinates.
    double centerX;
    double centerY;
    double centerZ;

    /// Box extents.
    double length;
    double width;
    double height;

    /// Heading angle of the box.
    double theta;

    /// Score used to rank candidates.
    double score;

    /// Stores the eight arguments into the members, in declaration order:
    /// center (x, y, z), extents (l, w, h), heading t, score s.
    BoundingBox3D(double x, double y, double z,
                  double l, double w, double h,
                  double t, double s)
        : centerX{x},
          centerY{y},
          centerZ{z},
          length{l},
          width{w},
          height{h},
          theta{t},
          score{s}
    {
    }
};
class NMS3D
{
public:
NMS3D(double iouThreshold) : iouThreshold_(iouThreshold) {}
std::vector<BoundingBox3D> executeNMS(const std::vector<BoundingBox3D> &boxes)
{
std::vector<BoundingBox3D> resultBoxes;
std::vector<BoundingBox3D> sortedBoxes = sortBoxesByScore(boxes);
while (!sortedBoxes.empty())
{
BoundingBox3D topBox = sortedBoxes[0];
resultBoxes.push_back(topBox);
sortedBoxes.erase(sortedBoxes.begin());
sortedBoxes = removeOverlappingBoxes(topBox, sortedBoxes);
}
return resultBoxes;
}
private:
std::vector<BoundingBox3D> sortBoxesByScore(const std::vector<BoundingBox3D> &boxes)
{
std::vector<BoundingBox3D> sortedBoxes = boxes;
std::sort(sortedBoxes.begin(), sortedBoxes.end(),
[](const BoundingBox3D &a, const BoundingBox3D &b)
{
return a.score > b.score;
});
return sortedBoxes;
}
std::vector<BoundingBox3D> removeOverlappingBoxes(const BoundingBox3D &box,
const std::vector<BoundingBox3D> &boxes)
{
std::vector<BoundingBox3D> filteredBoxes;
for (const auto &b : boxes)
{
if (calculateIoU(box, b) < iouThreshold_)
{
filteredBoxes.push_back(b);
}
}
return filteredBoxes;
}
double calculateIoU(const BoundingBox3D &box1, const BoundingBox3D &box2)
{
double intersectionVolume = calculateIntersectionVolume(box1, box2);
double unionVolume = box1.length * box1.width * box1.height +
box2.length * box2.width * box2.height -
intersectionVolume;
return intersectionVolume / unionVolume;
}
double calculateIntersectionVolume(const BoundingBox3D &box1, const BoundingBox3D &box2)
{
double intersectArea = calIntersectionArea(box1, box2);
double intersectHeight = calculateOverlap(box1.centerZ, box1.height, box2.centerZ, box2.height);
return intersectArea * intersectHeight;
}
cv::Point rotatePoint(const cv::Point &point, double angle)
{
double rotatedX = point.x * cos(angle) - point.y * sin(angle);
double rotatedY = point.x * sin(angle) + point.y * cos(angle);
return cv::Point(rotatedX, rotatedY);
}
double calIntersectionArea(const BoundingBox3D &box1, const BoundingBox3D &box2)
{
std::vector<cv::Point> triangle1 = {rotatePoint(cv::Point(box1.centerX - box1.width / 2, box1.centerY - box1.height / 2), box1.theta),
rotatePoint(cv::Point(box1.centerX + box1.width / 2, box1.centerY - box1.height / 2), box1.theta),
rotatePoint(cv::Point(box1.centerX + box1.width / 2, box1.centerY + box1.height / 2), box1.theta),
rotatePoint(cv::Point(box1.centerX - box1.width / 2, box1.centerY + box1.height / 2), box1.theta)};
std::vector<cv::Point> triangle2 = {rotatePoint(cv::Point(box2.centerX - box2.width / 2, box2.centerY - box2.height / 2), box2.theta),
rotatePoint(cv::Point(box2.centerX + box2.width / 2, box2.centerY - box2.height / 2), box2.theta),
rotatePoint(cv::Point(box2.centerX + box2.width / 2, box2.centerY + box2.height / 2), box2.theta),
rotatePoint(cv::Point(box2.centerX - box2.width / 2, box2.centerY + box2.height / 2), box2.theta)};
int height = -1;
int width = -1;
for (auto point : triangle1)
{
if (width < point.x)
width = point.x;
if (height < point.y)
height = point.y;
}
for (auto point : triangle2)
{
if (width < point.x)
width = point.x;
if (height < point.y)
height = point.y;
}
cv::Mat img = cv::Mat::zeros(height, width, CV_8UC3);
cv::fillConvexPoly(img, polygon1, cv::Scalar(0, 0, 1));
cv::fillConvexPoly(img, polygon2, cv::Scalar(0, 1, 0));
cv::Mat intersection = cv::Mat::zeros(img.size(), img.type());
cv::bitwise_and(img, img, intersection);
double union_area = cv::sum(intersection)[0];
}
double calculateOverlap(double center1, double size1, double center2, double size2)
{
double halfSize1 = size1 / 2;
double halfSize2 = size2 / 2;
double min1 = center1 - halfSize1;
double max1 = center1 + halfSize1;
double min2 = center2 - halfSize2;
double max2 = center2 + halfSize2;
return std::max(0.0, std::min(max1, max2) - std::max(min1, min2));
}
double iouThreshold_;
};
int main()
{
std::vector<BoundingBox3D> inputBoxes;
inputBoxes.push_back(BoundingBox3D(0.0, 0.0, 0.0, 2.0, 1.0, 1.0, 0, 0.9));
inputBoxes.push_back(BoundingBox3D(0.1, 0.1, 0.1, 2.0, 1.0, 1.0, 0, 0.8));
inputBoxes.push_back(BoundingBox3D(2.0, 2.0, 2.0, 2.0, 1.0, 1.0, 0, 0.7));
double iouThreshold = 0.5;
NMS3D nms(iouThreshold);
std::vector<BoundingBox3D> resultBoxes = nms.executeNMS(inputBoxes);
for (const auto &box : resultBoxes)
{
std::cout << "Center: (" << box.centerX << ", " << box.centerY << ", " << box.centerZ << "), "
<< "Dimensions: (" << box.length << ", " << box.width << ", " << box.height << "), "
<< "Theta: " << box.theta << ", "
<< "Score: " << box.score << std::endl;
}
return 0;
}