ByteTrack 多目标跟踪的 C++ 实践
flyfish
论文《ByteTrack: Multi-Object Tracking by Associating Every Detection Box》
链接:https://pan.baidu.com/s/12W_TMnphUqLbE9-mWxlawQ
提取码:0000
网盘中包括已经转换好的 ONNX 模型,C++ 代码可以直接使用。
运行环境:Ubuntu 18.04
OpenCV 4.5.5
eigen-3.3.9
原作者使用 YOLOX 作为检测模型,这里改用 YOLOv5(6.1 版本)。其余代码沿用 ByteTrack 的原始实现。
本着易于使用、易于搭建编译环境的原则,只依赖 OpenCV 和 Eigen。
目标检测模型部分
如果不想自己动手,可以直接使用网盘中现成的模型;如果想自己转换模型,需要按以下步骤操作:
检测模型部分
从 https://github.com/ultralytics/yolov5 下载 6.1 版本的源码和预训练模型。
搭建好 YOLOv5 运行环境之后,运行转换命令,例如:
python export.py --weights yolov5x.pt --include onnx
执行完成后即可得到 yolov5x.onnx。
C++部分
编译OpenCV
编译并安装 Eigen
unzip eigen-3.3.9.zip
cd eigen-3.3.9
mkdir build
cd build
cmake ..
sudo make install
main 文件的代码如下:
#include <chrono>
#include <fstream>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

#include <opencv2/imgproc.hpp>
#include <opencv2/opencv.hpp>
#include <opencv2/dnn.hpp>

#include "YOLOv5Detector.h"
#include "BYTETracker.h"
int main(int argc, char *argv[])
{
//-----------------------------------------------------------------------
// 加载类别名称
std::vector<std::string> classes;
std::string file="./coco_80_labels_list.txt";
std::ifstream ifs(file);
if (!ifs.is_open())
CV_Error(cv::Error::StsError, "File " + file + " not found");
std::string line;
while (std::getline(ifs, line))
{
classes.push_back(line);
}
//-----------------------------------------------------------------------
std::shared_ptr<YOLOv5Detector> detector(new YOLOv5Detector());
detector->init("./yolov5x.onnx");
cv::VideoCapture capture("rtsp://192.168.1.2/test");
std::vector<detect_result> results;
int fps=20;
BYTETracker tracker(fps, 30);
int num_frames = 0;
cv::VideoWriter video("out.avi",cv::VideoWriter::fourcc('M','J','P','G'),10, cv::Size(1920,1080));
while (true)
{
cv::Mat frame;
if (!capture.read(frame)) // if not success, break loop
{
std::cout<<"\n Cannot read the video file. please check your video.\n";
break;
}
num_frames ++;
//Second/Millisecond/Microsecond 秒s/毫秒ms/微秒us
auto start = std::chrono::system_clock::now();
detector->detect(frame, results);
auto end = std::chrono::system_clock::now();
auto detect_time =std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();//ms
std::vector<detect_result> objects;
for (detect_result dr : results)
{
cv::putText(frame, classes[dr.classId], cv::Point(dr.box.tl().x+10, dr.box.tl().y - 10), cv::FONT_HERSHEY_SIMPLEX, .8, cv::Scalar(0, 255, 0));
if(dr.classId == 0) //person
{
objects.push_back(dr);
}
}
start = std::chrono::system_clock::now();
std::vector<STrack> output_stracks = tracker.update(objects);
end = std::chrono::system_clock::now();
auto track_time =std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();//us
for (unsigned long i = 0; i < output_stracks.size(); i++)
{
std::vector<float> tlwh = output_stracks[i].tlwh;
bool vertical = tlwh[2] / tlwh[3] > 1.6;
if (tlwh[2] * tlwh[3] > 20 && !vertical)
{
cv::Scalar s = tracker.get_color(output_stracks[i].track_id);
cv::putText(frame, cv::format("%d", output_stracks[i].track_id), cv::Point(tlwh[0], tlwh[1] - 5),
0, 0.6, cv::Scalar(0, 0, 255), 2, cv::LINE_AA);
cv::rectangle(frame, cv::Rect(tlwh[0], tlwh[1], tlwh[2], tlwh[3]), s, 2);
}
}
cv::putText(frame, cv::format("detect ms:%ld # track us:%ld # current frame: %d",detect_time, track_time,num_frames),
cv::Point(1, 40), cv::FONT_HERSHEY_PLAIN, 2.0, cv::Scalar(255, 255, 255), 2, 8);
cv::imshow("YOLOv5-6.1", frame);
video.write(frame);
if(cv::waitKey(30) == 27) // Wait for 'esc' key press to exit
{
break;
}
results.clear();
}
capture.release();
video.release();
cv::destroyAllWindows();
}
可以在 Qt Creator 中以项目方式打开 CMakeLists.txt,加载项目后直接编译即可。
bin 目录中存放着已经转换好的模型和类别名称文件。