The SampleMNIST Class
Let's first look at the SampleMNIST class. It contains two main methods, build() and infer():
- build(): uses the TensorRT optimizer to convert the Caffe model into a TensorRT object. This requires the network definition file (Caffe's deploy.prototxt), the trained weights file (Caffe's net.caffemodel), and the mean file (Caffe's mean.binaryproto). The batch size must also be specified, and the input and output layers must be marked.
- infer(): runs TensorRT to perform the inference computation.
//!
//! The SampleMNIST class
//!
class SampleMNIST
{
    template <typename T>
    using SampleUniquePtr = std::unique_ptr<T, samplesCommon::InferDeleter>;

public:
    SampleMNIST(const MNISTSampleParams& params)
        : mParams(params)
    {
    }

    //! Builds the network engine
    bool build();

    //! Runs TensorRT inference
    bool infer();

    //! Cleans up any state created by the sample
    bool teardown();

    // ... ...
};
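The SampleUniquePtr alias above wraps every TensorRT object in a std::unique_ptr with a custom deleter, because TensorRT 5.x objects are released through their destroy() method rather than operator delete. In the samples' common code, InferDeleter is essentially the following (a sketch; the exact definition lives in common.h):

struct InferDeleter
{
    template <typename T>
    void operator()(T* obj) const
    {
        if (obj)
            obj->destroy(); // TensorRT objects are released via destroy(), not delete
    }
};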
build(): Converting the Model
bool SampleMNIST::build()
//!
//! Builds the network, configures the builder, and creates the CudaEngine
//!
bool SampleMNIST::build()
{
    // Create a builder; gLogger handles log output.
    auto builder = SampleUniquePtr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(gLogger));
    if (!builder)
        return false;

    // Create an empty network object, to be populated below.
    auto network = SampleUniquePtr<nvinfer1::INetworkDefinition>(builder->createNetwork());
    if (!network)
        return false;

    // Create a Caffe model parser.
    auto parser = SampleUniquePtr<nvcaffeparser1::ICaffeParser>(nvcaffeparser1::createCaffeParser());
    if (!parser)
        return false;

    // The parser runs over the model files and populates the network object.
    constructNetwork(builder, network, parser);
    builder->setMaxBatchSize(mParams.batchSize);
    builder->setMaxWorkspaceSize(16_MB);
    builder->allowGPUFallback(true);
    samplesCommon::enableDLA(builder.get(), mParams.dlaCore);

    // Build the CudaEngine from the network; the optimizations are applied here.
    mEngine = std::shared_ptr<nvinfer1::ICudaEngine>(builder->buildCudaEngine(*network), samplesCommon::InferDeleter());
    if (!mEngine)
        return false;

    // At this point the Caffe model has been converted into a TensorRT object.
    assert(network->getNbInputs() == 1);
    mInputDims = network->getInput(0)->getDimensions();
    assert(mInputDims.nbDims == 3);

    return true;
}
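buildCudaEngine() is the expensive step, so a real application usually serializes the resulting engine to disk and loads it on later runs instead of re-parsing and re-optimizing. This is not part of sampleMNIST, but a minimal sketch using the TensorRT 5.x API could look like the following (the saveEngine helper name and the file path are made up for illustration):

#include <fstream>
#include <string>
#include "NvInfer.h"

// Hypothetical helper: write a built engine to disk so later runs can
// load it with IRuntime::deserializeCudaEngine() instead of rebuilding.
bool saveEngine(nvinfer1::ICudaEngine& engine, const std::string& path)
{
    nvinfer1::IHostMemory* serialized = engine.serialize(); // engine -> flat binary blob
    if (!serialized)
        return false;
    std::ofstream file(path, std::ios::binary);
    file.write(static_cast<const char*>(serialized->data()), serialized->size());
    serialized->destroy();
    return file.good();
}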
The constructNetwork function used in build() above is:
//!
//! Uses the Caffe parser to build the MNIST network and mark its outputs
//!
void SampleMNIST::constructNetwork(SampleUniquePtr<nvinfer1::IBuilder>& builder,
    SampleUniquePtr<nvinfer1::INetworkDefinition>& network,
    SampleUniquePtr<nvcaffeparser1::ICaffeParser>& parser)
{
    // Parse the Caffe model and populate the network object: each Caffe blob
    // is converted into a TensorRT tensor, and the name-to-tensor mapping is
    // returned in blobNameToTensor. This step consumes the network definition
    // (prototxt) and weights (caffemodel) files.
    const nvcaffeparser1::IBlobNameToTensor* blobNameToTensor = parser->parse(
        locateFile(mParams.prototxtFileName, mParams.dataDirs).c_str(),
        locateFile(mParams.weightsFileName, mParams.dataDirs).c_str(),
        *network,
        nvinfer1::DataType::kFLOAT);

    // Mark the output blobs (there may be more than one).
    for (auto& s : mParams.outputTensorNames)
        network->markOutput(*blobNameToTensor->find(s.c_str()));

    // Load the mean file so the mean can be subtracted from every input image.
    Dims inputDims = network->getInput(0)->getDimensions();
    mMeanBlob = SampleUniquePtr<nvcaffeparser1::IBinaryProtoBlob>(parser->parseBinaryProto(locateFile(mParams.meanFileName, mParams.dataDirs).c_str()));
    Weights meanWeights{DataType::kFLOAT, mMeanBlob->getData(), inputDims.d[1] * inputDims.d[2]};
    auto mean = network->addConstant(Dims3(1, inputDims.d[1], inputDims.d[2]), meanWeights);
    auto meanSub = network->addElementWise(*network->getInput(0), *mean->getOutput(0), ElementWiseOperation::kSUB);
    // Rewire the first layer so it reads the mean-subtracted tensor
    // instead of the raw network input.
    network->getLayer(0)->setInput(0, *meanSub->getOutput(0));
}
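The last line is the interesting one: instead of preprocessing images on the CPU, the sample splices the mean subtraction into the network itself, so the first layer now consumes input - mean and infer() can feed raw pixel values. A hypothetical way to confirm the rewiring (not in the sample; it assumes the ILayer::getInput() accessor that pairs with setInput() in TensorRT 5.x):

// Layer 0 should now read the ElementWise (kSUB) output rather than
// the original "data" input tensor.
nvinfer1::ILayer* firstLayer = network->getLayer(0);
std::cout << "layer 0 input: " << firstLayer->getInput(0)->getName() << std::endl;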
infer(): Running Inference
bool SampleMNIST::infer()
//!
//! Runs the TensorRT inference engine
//!
bool SampleMNIST::infer()
{
    // The BufferManager handles buffer allocation and release on both host and device.
    samplesCommon::BufferManager buffers(mEngine, mParams.batchSize);

    // Create an execution context; it is used to launch the CUDA kernels for inference.
    auto context = SampleUniquePtr<nvinfer1::IExecutionContext>(mEngine->createExecutionContext());
    if (!context)
        return false;

    // Pick a random digit and read the corresponding image.
    srand(time(NULL));
    const int digit = rand() % 10;
    assert(mParams.inputTensorNames.size() == 1);
    if (!processInput(buffers, mParams.inputTensorNames[0], digit))
        return false;

    // Create a CUDA stream.
    cudaStream_t stream;
    CHECK(cudaStreamCreate(&stream));

    // Asynchronously copy the input from the host buffer to the device buffer.
    buffers.copyInputToDeviceAsync(stream);

    // Launch the CUDA kernels: inference runs asynchronously on the stream.
    if (!context->enqueue(mParams.batchSize, buffers.getDeviceBindings().data(), stream, nullptr))
        return false;

    // Copy the results from the device back to the host.
    buffers.copyOutputToHostAsync(stream);

    // Block until all work queued on the stream has finished.
    cudaStreamSynchronize(stream);

    // Destroy the stream.
    cudaStreamDestroy(stream);

    // Verify the output.
    assert(mParams.outputTensorNames.size() == 1);
    bool outputCorrect = verifyOutput(buffers, mParams.outputTensorNames[0], digit);

    return outputCorrect;
}
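processInput() and verifyOutput() are not shown in this walkthrough. For reference, processInput reads the chosen digit's PGM image into the input host buffer managed by BufferManager; a sketch along these lines, assuming the readPGMFile and locateFile helpers from the samples' common code:

bool SampleMNIST::processInput(const samplesCommon::BufferManager& buffers,
    const std::string& inputTensorName, int inputFileIdx) const
{
    const int inputH = mInputDims.d[1];
    const int inputW = mInputDims.d[2];

    // Read the raw <digit>.pgm image, e.g. "3.pgm".
    std::vector<uint8_t> fileData(inputH * inputW);
    readPGMFile(locateFile(std::to_string(inputFileIdx) + ".pgm", mParams.dataDirs),
        fileData.data(), inputH, inputW);

    // Copy the pixels into the host-side input buffer; the in-network mean
    // subtraction added in constructNetwork() takes care of normalization.
    float* hostInputBuffer = static_cast<float*>(buffers.getHostBuffer(inputTensorName));
    for (int i = 0; i < inputH * inputW; i++)
        hostInputBuffer[i] = float(fileData[i]);
    return true;
}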
Main Function
int main(int argc, char** argv)
{
    samplesCommon::Args args;
    if (!samplesCommon::parseArgs(args, argc, argv))
    {
        if (args.help)
        {
            printHelpInfo();
            return EXIT_SUCCESS;
        }
        return EXIT_FAILURE;
    }

    MNISTSampleParams params = initializeSampleParams(args);
    SampleMNIST sample(params);

    std::cout << "Building and running a GPU inference engine for MNIST" << std::endl;

    if (!sample.build())
        return EXIT_FAILURE;
    if (!sample.infer())
        return EXIT_FAILURE;
    if (!sample.teardown())
        return EXIT_FAILURE;

    return EXIT_SUCCESS;
}
The initializeSampleParams function is as follows:
//!
//! Initializes the sample parameters from the command-line arguments
//!
MNISTSampleParams initializeSampleParams(const samplesCommon::Args& args)
{
    MNISTSampleParams params;
    if (args.dataDirs.size() != 0) // directories supplied by the user
        params.dataDirs = args.dataDirs;
    else // default directories: data/mnist/ or data/samples/mnist/
    {
        params.dataDirs.push_back("data/mnist/");
        params.dataDirs.push_back("data/samples/mnist/");
    }

    params.prototxtFileName = "mnist.prototxt";     // network definition file
    params.weightsFileName = "mnist.caffemodel";    // trained weights file
    params.meanFileName = "mnist_mean.binaryproto"; // mean file
    params.inputTensorNames.push_back("data");      // input layer name
    params.batchSize = 1;                           // batch size
    params.outputTensorNames.push_back("prob");     // output layer name
    params.dlaCore = args.useDLACore;               // DLA core to run on, if any

    return params;
}
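With these defaults, the sample is typically run from the root of the TensorRT package so that data/mnist/ resolves; the common argument parser also accepts an explicit data directory and a DLA core (the path below is illustrative):

./sample_mnist
./sample_mnist --datadir=/path/to/TensorRT/data/mnist/ --useDLACore=0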
References
1. TensorRT Installation Guide: https://docs.nvidia.com/deeplearning/sdk/tensorrt-install-guide/index.html
2. TensorRT Documentation (C++ API): https://docs.nvidia.com/deeplearning/sdk/tensorrt-api/c_api/index.html
3. Best Practices For TensorRT Performance: https://docs.nvidia.com/deeplearning/sdk/tensorrt-best-practices/index.html
4. TensorRT Developer Guide: https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html
5. TensorRT API: https://docs.nvidia.com/deeplearning/sdk/tensorrt-api/index.html
6. Samples Support Guide: https://docs.nvidia.com/deeplearning/sdk/tensorrt-sample-support-guide/index.html
7. "Hello World" For TensorRT: https://docs.nvidia.com/deeplearning/sdk/tensorrt-sample-support-guide/index.html#mnist_sample