0
点赞
收藏
分享

微信扫一扫

C++实现字符串创建类对象

一、简介

深度学习部署过程中,有时候需要部署多个模型,每个模型的接口相近,对于常规C++代码,每个模型对应一个类,调用过程中需要手动对每个类进行操作,对于程序员来讲,这种不智能的方法难以接收,学过java, python的人应该对反射机制有所了解,无奈C++暂还不支持这种方式,实属C++的一个缺陷吧!

但方法总比困难多,反射机制不支持,用别的方式来达到类似的目标,主要采用设计模型中的单例+工厂模式,实现字符串创建类的方式。此外,由于多种模型接口相似,可以使用基类写法,其他类继承该基类,实现自动操作。

二、原理分析

反射是程序可以访问、检测和修改它本身状态或行为的一种能力,简单地说,通过字符串来创建类。

第一步:创建单例模板,用于单例工厂的创建

// 单例类模板
template<typename T>
class Singleton
{
public:
    static T* GetInstance()
    {
        static T instance;
        return &instance;
    }

    Singleton(T&&) = delete;
    Singleton(const T&) = delete;
    void operator= (const T&) = delete;

protected:
    Singleton() = default;
    virtual ~Singleton() = default;
};

第二步:定义函数指针类型:用于指向创建类实例的回调函数

using CreateObjectFunc = function<void*()>;

第三步:创建工厂类实现类与字符串的映射关系

// 创建对象的回调函数
struct CreateObjectFuncClass {
    explicit CreateObjectFuncClass(CreateObjectFunc func) : create_func(func) {}
    CreateObjectFunc create_func;
};

// Object工厂类
class ObjectFactory : public Singleton<ObjectFactory> {
public:
    // 返回void *减少了代码的耦合
        // 提供给外部注册以及类创建
    void* CreateObject(const string& class_name) {
        CreateObjectFunc createobj = nullptr;

        if (create_funcs_.find(class_name) != create_funcs_.end())
            createobj = create_funcs_.find(class_name)->second->create_func;

        if (createobj == nullptr)
            return nullptr;

        // 调用函数指针指向的函数 调用REGISTER_CLASS中宏的绑定函数,也就是运行new className代码
        return createobj();
    }

        // 保存类名字符串到类对象构造函数指针的映射
    void RegisterObject(const string& class_name, CreateObjectFunc func) {
        auto it = create_funcs_.find(class_name);
        if (it != create_funcs_.end())
            create_funcs_[class_name]->create_func = func;
        else
            create_funcs_.emplace(class_name, new CreateObjectFuncClass(func));
    }

    ~ObjectFactory() {
        for (auto it : create_funcs_)
        {
            if (it.second != nullptr)
            {
                delete it.second;
                it.second = nullptr;
            }
        }
        create_funcs_.clear();
    }

private:
    // 缓存类名和生成类实例函数指针的map
    unordered_map<string, CreateObjectFuncClass* > create_funcs_;
};

第四步:宏定义,方便注册

#define REGISTERCLASS(className) \
class className##Helper { \
public: \
    className##Helper() \
    { \
        ObjectFactory::GetInstance()->RegisterObject(#className, []() \
        { \
            auto* obj = new className(); \
                        // 这个可以指定默认的执行的函数
            // obj->SetModelName(#className); \
            return obj; \
        }); \
    } \
}; \
className##Helper g_##className##_helper;// 初始化一个helper的全局变量,执行构造函数中的RegisterObject执行。 
  • 备注
    1. 宏定义中 #:转为字符串##:连接两个字符串\:代码换行符

第五步:定义各个模型基类,其他模型都继承该基类

class BasicNet{
public:
        //构造函数至少要这个,因为register时候使用
        BasicNet() {}
        bool LoadModel(const std::string modelPath){
            //code
            return true;
        }
        void SetModelName(std::string modelName) {_modelNname=modelName;}
        virtual ~BasicNet(){}
        std::string GetModelName() const {return _modelNname;}
protected:
        std::string _modelNname;
};
class AlexNet:public BasicNet{
public:
        ~AlexNet(){}
};
class LeNet:public BasicNet{
public:
        ~LeNet(){}
};

第六步:测试结果

int main() {
    REGISTERCLASS(AlexNet)
    REGISTERCLASS(LeNet)
    std::vector<std::string> models{"AlexNet","LeNet"};
    for(auto model:models){
        auto alexNet = (BasicNet*)ObjectFactory::GetInstance()->CreateObject(model);
        alexNet->SetModelName(model);
        std::string name = alexNet->GetModelName();
        cout << name.c_str() << endl;
        delete alexNet;
    }
    return 0;
}

///////////////////// 结果
AlexNet
LeNet

三、代码详解

#include <iostream>
#include <unordered_map>
#include <functional>
#include <vector>
using namespace std;

// 单例类模板
template<typename T>
class Singleton
{
public:
    static T* GetInstance()
    {
        static T instance;
        return &instance;
    }

    Singleton(T&&) = delete;
    Singleton(const T&) = delete;
    void operator= (const T&) = delete;

protected:
    Singleton() = default;
    virtual ~Singleton() = default;
};
using CreateObjectFunc = function<void*()>;
// 创建对象的回调函数
struct CreateObjectFuncClass {
    explicit CreateObjectFuncClass(CreateObjectFunc func) : create_func(func) {}
    CreateObjectFunc create_func;
};

// Object工厂类
class ObjectFactory : public Singleton<ObjectFactory> {
public:
    // 返回void *减少了代码的耦合
        // 提供给外部注册以及类创建
    void* CreateObject(const string& class_name) {
        CreateObjectFunc createobj = nullptr;

        if (create_funcs_.find(class_name) != create_funcs_.end())
            createobj = create_funcs_.find(class_name)->second->create_func;

        if (createobj == nullptr)
            return nullptr;

        // 调用函数指针指向的函数 调用REGISTER_CLASS中宏的绑定函数,也就是运行new className代码
        return createobj();
    }

        // 保存类名字符串到类对象构造函数指针的映射
    void RegisterObject(const string& class_name, CreateObjectFunc func) {
        auto it = create_funcs_.find(class_name);
        if (it != create_funcs_.end())
            create_funcs_[class_name]->create_func = func;
        else
            create_funcs_.emplace(class_name, new CreateObjectFuncClass(func));
    }

    ~ObjectFactory() {
        for (auto it : create_funcs_)
        {
            if (it.second != nullptr)
            {
                delete it.second;
                it.second = nullptr;
            }
        }
        create_funcs_.clear();
    }

private:
    // 缓存类名和生成类实例函数指针的map
    unordered_map<string, CreateObjectFuncClass* > create_funcs_;
};
#define REGISTERCLASS(className) \
class className##Helper { \
public: \
    className##Helper() \
    { \
        ObjectFactory::GetInstance()->RegisterObject(#className, []() \
        { \
            auto* obj = new className(); \
            return obj; \
        }); \
    } \
}; \
className##Helper g_##className##_helper;// 初始化一个helper的全局变量,执行构造函数中的RegisterObject执行。 

class BasicNet{
public:
        BasicNet() {}
        bool LoadModel(const std::string modelPath){
            //code
            return true;
        }
        void SetModelName(std::string modelName) {_modelNname=modelName;}
        virtual ~BasicNet(){}
        std::string GetModelName() const {return _modelNname;}
protected:
        std::string _modelNname;
};
class AlexNet:public BasicNet{
public:
        ~AlexNet(){}
};
class LeNet:public BasicNet{
public:
        ~LeNet(){}
};

int main() {
    REGISTERCLASS(AlexNet)
    REGISTERCLASS(LeNet)
    std::vector<std::string> models{"AlexNet","LeNet"};
    for(auto model:models){
        auto alexNet = (BasicNet*)ObjectFactory::GetInstance()->CreateObject(model);
        alexNet->SetModelName(model);
        std::string name = alexNet->GetModelName();
        cout << name.c_str() << endl;
        delete alexNet;
    }

    return 0;
}

四、参考

举报

相关推荐

c++字符串

C++字符串拼接

C++的字符串

C++字符串详解

c++ 字符串插入

C++分割字符串

字符串压缩(C++)

0 条评论