A First TensorRT Example: MNIST Handwritten Digit Recognition

阎小妍 2022-08-08


The SampleMNIST Class

Let's look at the SampleMNIST class first. It provides two main methods, build() and infer():

  • build(): converts the Caffe model into a TensorRT object using the TensorRT optimizer.
    This requires the network definition file (Caffe's deploy.prototxt), the trained weights file (net.caffemodel),
    and the mean file (mean.binaryproto). It also sets the maximum batch size and marks the input and output layers.
  • infer(): runs TensorRT to perform the inference computation.

//!
//! The SampleMNIST class
//!
class SampleMNIST
{
    template <typename T>
    using SampleUniquePtr = std::unique_ptr<T, samplesCommon::InferDeleter>;

public:
    SampleMNIST(const MNISTSampleParams& params)
        : mParams(params)
    {
    }

    //! Builds the network engine
    bool build();

    //! Runs TensorRT inference
    bool infer();

    //! Cleans up resources created by the sample
    bool teardown();
    ... ...
};
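The SampleUniquePtr alias pairs each TensorRT object with samplesCommon::InferDeleter, so the object's destroy() method is called automatically when the pointer goes out of scope (TensorRT objects of this era are released via destroy(), not delete). A minimal sketch of how such a deleter is typically defined in the samples' common headers:

struct InferDeleter
{
    template <typename T>
    void operator()(T* obj) const
    {
        if (obj)
            obj->destroy(); // release TensorRT objects through their destroy() method
    }
};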

build(): Converting the Model

//!
//! Creates the network, configures the builder, and builds the CudaEngine
//!
bool SampleMNIST::build()
{
    // Create a builder; gLogger handles log output.
    auto builder = SampleUniquePtr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(gLogger));
    if (!builder)
        return false;

    // Create an empty network object, populated below.
    auto network = SampleUniquePtr<nvinfer1::INetworkDefinition>(builder->createNetwork());
    if (!network)
        return false;

    // Create a Caffe model parser.
    auto parser = SampleUniquePtr<nvcaffeparser1::ICaffeParser>(nvcaffeparser1::createCaffeParser());
    if (!parser)
        return false;

    // Parse the Caffe model to populate the network, then configure the builder.
    constructNetwork(builder, network, parser);
    builder->setMaxBatchSize(mParams.batchSize);
    builder->setMaxWorkspaceSize(16_MB);
    builder->allowGPUFallback(true);
    samplesCommon::enableDLA(builder.get(), mParams.dlaCore);

    // Build the CudaEngine from the network; the optimizations run here.
    mEngine = std::shared_ptr<nvinfer1::ICudaEngine>(builder->buildCudaEngine(*network), samplesCommon::InferDeleter());
    if (!mEngine)
        return false;

    // At this point the Caffe model has been converted into a TensorRT object.

    assert(network->getNbInputs() == 1);
    mInputDims = network->getInput(0)->getDimensions();
    assert(mInputDims.nbDims == 3);

    return true;
}
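Note that setMaxWorkspaceSize bounds the scratch GPU memory TensorRT may use while optimizing and running the engine. The 16_MB argument relies on a user-defined literal from the samples' common headers; a sketch of how such a literal is presumably defined:

// Converts a megabyte count to bytes, e.g. 16_MB == 16 * 1048576
constexpr long long int operator"" _MB(unsigned long long val)
{
    return val * (1 << 20);
}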

The constructNetwork helper used above is:

//!
//! Uses the Caffe parser to construct the MNIST network and marks the outputs
//!
void SampleMNIST::constructNetwork(SampleUniquePtr<nvinfer1::IBuilder>& builder, SampleUniquePtr<nvinfer1::INetworkDefinition>& network, SampleUniquePtr<nvcaffeparser1::ICaffeParser>& parser)
{
    // Parse the Caffe model and populate the network object. Each Caffe blob
    // is mapped to a TensorRT tensor; the mapping is returned in blobNameToTensor.
    // This step consumes the network definition file and the weights file.
    const nvcaffeparser1::IBlobNameToTensor* blobNameToTensor = parser->parse(
        locateFile(mParams.prototxtFileName, mParams.dataDirs).c_str(),
        locateFile(mParams.weightsFileName, mParams.dataDirs).c_str(),
        *network,
        nvinfer1::DataType::kFLOAT);

    // Mark the output blobs (there may be more than one).
    for (auto& s : mParams.outputTensorNames)
        network->markOutput(*blobNameToTensor->find(s.c_str()));

    // Load the mean file; the mean image is subtracted from every input image.
    Dims inputDims = network->getInput(0)->getDimensions();
    mMeanBlob = SampleUniquePtr<nvcaffeparser1::IBinaryProtoBlob>(parser->parseBinaryProto(locateFile(mParams.meanFileName, mParams.dataDirs).c_str()));
    Weights meanWeights{DataType::kFLOAT, mMeanBlob->getData(), inputDims.d[1] * inputDims.d[2]};

    // Add the mean as a constant layer, subtract it from the network input
    // element-wise, and rewire the first layer of the original network to
    // consume the mean-subtracted tensor instead of the raw input.
    auto mean = network->addConstant(Dims3(1, inputDims.d[1], inputDims.d[2]), meanWeights);
    auto meanSub = network->addElementWise(*network->getInput(0), *mean->getOutput(0), ElementWiseOperation::kSUB);
    network->getLayer(0)->setInput(0, *meanSub->getOutput(0));
}
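Because mean subtraction is folded into the network itself, the raw image bytes can be fed to the engine without host-side preprocessing. A hypothetical alternative would be to subtract the mean on the host before copying the input to the device; a sketch (subtractMeanOnHost is an illustrative helper, not part of the sample):

#include <cstdint>

// Hypothetical helper: subtract the parsed mean image from the raw input bytes
// on the host, instead of adding the constant/element-wise layers above.
void subtractMeanOnHost(const uint8_t* fileData, const float* meanData, float* hostInputBuffer, int count)
{
    for (int i = 0; i < count; i++)
        hostInputBuffer[i] = float(fileData[i]) - meanData[i];
}

// Usage: subtractMeanOnHost(fileData.data(), static_cast<const float*>(mMeanBlob->getData()), hostInputBuffer, inputH * inputW);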

infer(): Running Inference

//!
//! Runs the TensorRT inference engine
//!
bool SampleMNIST::infer()
{
    // The BufferManager handles buffer allocation and release.
    samplesCommon::BufferManager buffers(mEngine, mParams.batchSize);

    // Create the execution context, used to launch the CUDA kernels during inference.
    auto context = SampleUniquePtr<nvinfer1::IExecutionContext>(mEngine->createExecutionContext());
    if (!context)
        return false;

    // Pick a digit at random and read the corresponding image.
    srand(time(NULL));
    const int digit = rand() % 10;

    assert(mParams.inputTensorNames.size() == 1);
    if (!processInput(buffers, mParams.inputTensorNames[0], digit))
        return false;

    // Create a CUDA stream.
    cudaStream_t stream;
    CHECK(cudaStreamCreate(&stream));

    // Asynchronously copy the input data from host buffers to device buffers.
    buffers.copyInputToDeviceAsync(stream);

    // Launch the CUDA kernels; inference runs asynchronously on the stream.
    if (!context->enqueue(mParams.batchSize, buffers.getDeviceBindings().data(), stream, nullptr))
        return false;

    // Copy the results from the device back to the host.
    buffers.copyOutputToHostAsync(stream);

    // Wait for all work queued on the stream to finish.
    cudaStreamSynchronize(stream);

    // Destroy the stream.
    cudaStreamDestroy(stream);

    // Verify the output.
    assert(mParams.outputTensorNames.size() == 1);
    bool outputCorrect = verifyOutput(buffers, mParams.outputTensorNames[0], digit);

    return outputCorrect;
}
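The processInput and verifyOutput helpers are not shown above. Here is a sketch of what they do, assuming the readPGMFile helper from the samples' common code (the exact bodies may differ from the shipped sample):

// Read the chosen digit image into the host input buffer.
bool SampleMNIST::processInput(const samplesCommon::BufferManager& buffers, const std::string& inputTensorName, int inputFileIdx) const
{
    const int inputH = mInputDims.d[1];
    const int inputW = mInputDims.d[2];

    // Load e.g. "3.pgm" as raw 8-bit grayscale pixels.
    std::vector<uint8_t> fileData(inputH * inputW);
    readPGMFile(locateFile(std::to_string(inputFileIdx) + ".pgm", mParams.dataDirs), fileData.data(), inputH, inputW);

    // Convert to float; mean subtraction already happens inside the network.
    float* hostInputBuffer = static_cast<float*>(buffers.getHostBuffer(inputTensorName));
    for (int i = 0; i < inputH * inputW; i++)
        hostInputBuffer[i] = float(fileData[i]);
    return true;
}

// Check that the highest-probability class matches the digit that was read.
bool SampleMNIST::verifyOutput(const samplesCommon::BufferManager& buffers, const std::string& outputTensorName, int groundTruthDigit) const
{
    // The "prob" layer emits a 10-way softmax distribution.
    const float* prob = static_cast<const float*>(buffers.getHostBuffer(outputTensorName));
    int predicted = 0;
    for (int i = 1; i < 10; i++)
        if (prob[i] > prob[predicted])
            predicted = i;
    return predicted == groundTruthDigit;
}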

The Main Function

int main(int argc, char** argv)
{
    samplesCommon::Args args;

    if (!samplesCommon::parseArgs(args, argc, argv))
    {
        if (args.help)
        {
            printHelpInfo();
            return EXIT_SUCCESS;
        }
        return EXIT_FAILURE;
    }
    MNISTSampleParams params = initializeSampleParams(args);

    SampleMNIST sample(params);
    std::cout << "Building and running a GPU inference engine for MNIST" << std::endl;

    if (!sample.build())
        return EXIT_FAILURE;

    if (!sample.infer())
        return EXIT_FAILURE;

    if (!sample.teardown())
        return EXIT_FAILURE;

    return EXIT_SUCCESS;
}
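With the binary built, the sample can be run from the command line. Judging by the argument names used above (dataDirs and useDLACore), the invocation presumably looks like this, with both flags optional:

./sample_mnist [--datadir=/path/to/data/] [--useDLACore=N]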

The initializeSampleParams function is as follows:

//!
//! Initializes the sample parameters from the command-line arguments
//!
MNISTSampleParams initializeSampleParams(const samplesCommon::Args& args)
{
    MNISTSampleParams params;
    if (args.dataDirs.size() != 0) // directories supplied by the user
        params.dataDirs = args.dataDirs;
    else // default directories: data/mnist/ or data/samples/mnist/
    {
        params.dataDirs.push_back("data/mnist/");
        params.dataDirs.push_back("data/samples/mnist/");
    }
    params.prototxtFileName = "mnist.prototxt";     // network definition file
    params.weightsFileName = "mnist.caffemodel";    // trained weights file
    params.meanFileName = "mnist_mean.binaryproto"; // mean file
    params.inputTensorNames.push_back("data");      // input layer name
    params.batchSize = 1;                           // batch size
    params.outputTensorNames.push_back("prob");     // output layer name
    params.dlaCore = args.useDLACore;               // DLA core to run on, if any

    return params;
}
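locateFile comes from the samples' common code: it searches each directory in dataDirs in order and returns the first path where the file is found. A simplified sketch of that behavior (the real helper has more elaborate search and error handling):

#include <fstream>
#include <string>
#include <vector>

std::string locateFile(const std::string& file, const std::vector<std::string>& dirs)
{
    for (const auto& dir : dirs)
    {
        std::string path = dir + file;
        if (std::ifstream(path).good())
            return path; // first directory containing the file wins
    }
    return file; // fall back to the bare file name
}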


