一、简介
深度学习部署过程中,有时候需要部署多个模型,每个模型的接口相近
,对于常规C++代码,每个模型对应一个类,调用过程中需要手动
对每个类进行操作,对于程序员来讲,这种不智能的方法难以接收,学过java, python的人应该对反射机制
有所了解,无奈C++暂还不支持这种方式,实属C++的一个缺陷吧!
但方法总比困难多,反射机制不支持,用别的方式来达到类似的目标,主要采用设计模型中的单例+工厂模式
,实现字符串创建类
的方式。此外,由于多种模型接口相似,可以使用基类写法,其他类继承该基类
,实现自动操作。
二、原理分析
反射是程序可以访问、检测和修改它本身状态或行为的一种能力,简单地说,通过字符串来创建类。
第一步:创建单例模板,用于单例工厂的创建
// 单例类模板
template<typename T>
class Singleton
{
public:
static T* GetInstance()
{
static T instance;
return &instance;
}
Singleton(T&&) = delete;
Singleton(const T&) = delete;
void operator= (const T&) = delete;
protected:
Singleton() = default;
virtual ~Singleton() = default;
};
第二步:定义函数指针类型:用于指向创建类实例的回调函数
using CreateObjectFunc = function<void*()>;
第三步:创建工厂类实现类与字符串的映射关系
// 创建对象的回调函数
struct CreateObjectFuncClass {
explicit CreateObjectFuncClass(CreateObjectFunc func) : create_func(func) {}
CreateObjectFunc create_func;
};
// Object工厂类
class ObjectFactory : public Singleton<ObjectFactory> {
public:
// 返回void *减少了代码的耦合
// 提供给外部注册以及类创建
void* CreateObject(const string& class_name) {
CreateObjectFunc createobj = nullptr;
if (create_funcs_.find(class_name) != create_funcs_.end())
createobj = create_funcs_.find(class_name)->second->create_func;
if (createobj == nullptr)
return nullptr;
// 调用函数指针指向的函数 调用REGISTER_CLASS中宏的绑定函数,也就是运行new className代码
return createobj();
}
// 保存类名字符串到类对象构造函数指针的映射
void RegisterObject(const string& class_name, CreateObjectFunc func) {
auto it = create_funcs_.find(class_name);
if (it != create_funcs_.end())
create_funcs_[class_name]->create_func = func;
else
create_funcs_.emplace(class_name, new CreateObjectFuncClass(func));
}
~ObjectFactory() {
for (auto it : create_funcs_)
{
if (it.second != nullptr)
{
delete it.second;
it.second = nullptr;
}
}
create_funcs_.clear();
}
private:
// 缓存类名和生成类实例函数指针的map
unordered_map<string, CreateObjectFuncClass* > create_funcs_;
};
第四步:宏定义,方便注册
#define REGISTERCLASS(className) \
class className##Helper { \
public: \
className##Helper() \
{ \
ObjectFactory::GetInstance()->RegisterObject(#className, []() \
{ \
auto* obj = new className(); \
// 这个可以指定默认的执行的函数
// obj->SetModelName(#className); \
return obj; \
}); \
} \
}; \
className##Helper g_##className##_helper;// 初始化一个helper的全局变量,执行构造函数中的RegisterObject执行。
- 备注
- 宏定义中
#:转为字符串
,##:连接两个字符串
,\:代码
换行符
- 宏定义中
第五步:定义各个模型基类,其他模型都继承该基类
class BasicNet{
public:
//构造函数至少要这个,因为register时候使用
BasicNet() {}
bool LoadModel(const std::string modelPath){
//code
return true;
}
void SetModelName(std::string modelName) {_modelNname=modelName;}
virtual ~BasicNet(){}
std::string GetModelName() const {return _modelNname;}
protected:
std::string _modelNname;
};
class AlexNet:public BasicNet{
public:
~AlexNet(){}
};
class LeNet:public BasicNet{
public:
~LeNet(){}
};
第六步:测试结果
int main() {
REGISTERCLASS(AlexNet)
REGISTERCLASS(LeNet)
std::vector<std::string> models{"AlexNet","LeNet"};
for(auto model:models){
auto alexNet = (BasicNet*)ObjectFactory::GetInstance()->CreateObject(model);
alexNet->SetModelName(model);
std::string name = alexNet->GetModelName();
cout << name.c_str() << endl;
delete alexNet;
}
return 0;
}
///////////////////// 结果
AlexNet
LeNet
三、代码详解
#include <iostream>
#include <unordered_map>
#include <functional>
#include <vector>
using namespace std;
// 单例类模板
template<typename T>
class Singleton
{
public:
static T* GetInstance()
{
static T instance;
return &instance;
}
Singleton(T&&) = delete;
Singleton(const T&) = delete;
void operator= (const T&) = delete;
protected:
Singleton() = default;
virtual ~Singleton() = default;
};
using CreateObjectFunc = function<void*()>;
// 创建对象的回调函数
struct CreateObjectFuncClass {
explicit CreateObjectFuncClass(CreateObjectFunc func) : create_func(func) {}
CreateObjectFunc create_func;
};
// Object工厂类
class ObjectFactory : public Singleton<ObjectFactory> {
public:
// 返回void *减少了代码的耦合
// 提供给外部注册以及类创建
void* CreateObject(const string& class_name) {
CreateObjectFunc createobj = nullptr;
if (create_funcs_.find(class_name) != create_funcs_.end())
createobj = create_funcs_.find(class_name)->second->create_func;
if (createobj == nullptr)
return nullptr;
// 调用函数指针指向的函数 调用REGISTER_CLASS中宏的绑定函数,也就是运行new className代码
return createobj();
}
// 保存类名字符串到类对象构造函数指针的映射
void RegisterObject(const string& class_name, CreateObjectFunc func) {
auto it = create_funcs_.find(class_name);
if (it != create_funcs_.end())
create_funcs_[class_name]->create_func = func;
else
create_funcs_.emplace(class_name, new CreateObjectFuncClass(func));
}
~ObjectFactory() {
for (auto it : create_funcs_)
{
if (it.second != nullptr)
{
delete it.second;
it.second = nullptr;
}
}
create_funcs_.clear();
}
private:
// 缓存类名和生成类实例函数指针的map
unordered_map<string, CreateObjectFuncClass* > create_funcs_;
};
#define REGISTERCLASS(className) \
class className##Helper { \
public: \
className##Helper() \
{ \
ObjectFactory::GetInstance()->RegisterObject(#className, []() \
{ \
auto* obj = new className(); \
return obj; \
}); \
} \
}; \
className##Helper g_##className##_helper;// 初始化一个helper的全局变量,执行构造函数中的RegisterObject执行。
class BasicNet{
public:
BasicNet() {}
bool LoadModel(const std::string modelPath){
//code
return true;
}
void SetModelName(std::string modelName) {_modelNname=modelName;}
virtual ~BasicNet(){}
std::string GetModelName() const {return _modelNname;}
protected:
std::string _modelNname;
};
class AlexNet:public BasicNet{
public:
~AlexNet(){}
};
class LeNet:public BasicNet{
public:
~LeNet(){}
};
int main() {
REGISTERCLASS(AlexNet)
REGISTERCLASS(LeNet)
std::vector<std::string> models{"AlexNet","LeNet"};
for(auto model:models){
auto alexNet = (BasicNet*)ObjectFactory::GetInstance()->CreateObject(model);
alexNet->SetModelName(model);
std::string name = alexNet->GetModelName();
cout << name.c_str() << endl;
delete alexNet;
}
return 0;
}