0
点赞
收藏
分享

微信扫一扫

TVM C++代码的Object ObjectPtr ObjectRef关系

猎书客er 2022-02-18 阅读 147

TVM的Object类是很多类的基类,详细的分析资料可以参考

深入理解TVM:Object家族 - 知乎

深入理解TVM:Object家族(二) - 知乎

TVM源码品读:万物基石——Object类(1) - 知乎

TVM源码品读:万物基石——Object(2) - 知乎

在阅读TVM C++代码的时候,有很多Object的派生类的类型转换需要追溯到Object/ObjectPtr/ObjectRef,所以这里着重分析三者之间的关系。我们可以只保留三者的包含关系代码:

class TVM_DLL Object {
public:
 
    ...
protected:
 
    ...
private:
    ...
    friend class ObjectPtr;
    ...
};


template <typename T>
class ObjectPtr {
public:
    ...

private:
    Object* data_{nullptr};
    ...
    friend class Object;
    friend class ObjectRef;
    ...
};


class ObjectRef {
public:
    ...  

protected:

    ObjectPtr<Object> data_;
  
};

 从上面的代码可以看到,ObjectPtr的数据成员data_是一个Object指针,ObjectRef的数据成员data_是一个 ObjectPtr实例。

这里ObjectPtr可以通过自己的私有成员data_操作对应的Object实例,那么ObjectPtr想操作对应的ObjectRef怎么办呢?ObjectPtr的定义中申明了一个友元函数GetRef:

template <typename T>
class ObjectPtr {
public:
    ...

private:
    Object* data_{nullptr};
    ...
    friend class Object;
    friend class ObjectRef;
    ...
    template <typename RelayRefType, typename ObjType>
    friend RelayRefType GetRef(const ObjType* ptr);
    ...
};

template <typename RefType, typename ObjType>
inline RefType GetRef(const ObjType* ptr) {
    static_assert(std::is_base_of<typename RefType::ContainerType, ObjType>::value,
                "Can only cast to the ref of same container type");
    if (!RefType::_type_is_nullable) {
      ICHECK(ptr != nullptr);
    }
    return RefType(ObjectPtr<Object>(const_cast<Object*>(static_cast<const Object*>(ptr))));
}

这里我们忽略GetRef函数一开始的检查,只看最后return的那一句。

const_cast<Object*>(static_cast<const Object*>(ptr)) 是将一个ObjType类型的指针强转为Object类型并强制去掉const属性。在TVM的Object家族中,所有以Node为结尾的类名都是继承自Object, 不以Node结尾的类名都是继承自ObjectRef。所以某个以Node结尾的类实例,是可以强转为Object*类型的。

ObjectPtr<object>(xxx)这个就调用了ObjectPtr的构造函数生成一个ObjectPtr实例:

  explicit ObjectPtr(Object* data) : data_(data) {
    if (data != nullptr) {
      data_->IncRef();
    }
  }

 IncRef是增加Object的引用次数,这里不细究

接下来RefType(xxx)是调用RefType的构造函数,参数为 ObjectPtr类型,生成一个RefType类实例。如果这个RefType就是ObjectRef,看下对应的构造函数:

explicit ObjectRef(ObjectPtr<Object> data) : data_(data) {}

这样就生成了一个ObjectRef类型。如果RefType是ObjectRef的子类,并且是ObjType对应的类型(比如IRModule和IRModuleNode),那么就可以由某个类型的指针类型,得到对应的Ref类型了。

这种转换在代码中很多,比如说IRModule::FromExprInContext中:

if (auto* func_node = expr.as<BaseFuncNode>()) {
    func = GetRef<BaseFunc>(func_node);

这里expr强转为BaseFuncNode类型,以Node结尾的类型都是继承自Object,所以这里使用GetRef就可以转换为BaseFunc类型(不以Node结尾的都继承自ObjectRef)。

这个代码里面as是ObjectRef的成员方法,用作继承路径上下游之间的转换,as的定义:

template <typename ObjectType>
inline const ObjectType* ObjectRef::as() const {
  if (data_ != nullptr && data_->IsInstance<ObjectType>()) {
    return static_cast<ObjectType*>(data_.get());
  } else {
    return nullptr;
  }
}

替换下模板类型

inline const BaseFuncNode* ObjectRef::as() const {
  if (data_ != nullptr && data_->IsInstance<BaseFuncNode>()) {
    return static_cast<BaseFuncNode*>(data_.get());
  } else {
    return nullptr;
  }
}

在IRModule::FromExprInContext中expr参数的类型为RelayExpr类型,而我们从python传到这里的是一个Function类型,class Function继承自BaseFunc, BaseFunc继承自RelayExpr,基类是ObjectRef。

IsInstance方法是通过类型的type index判断一个实例属不属于某种类型或者子类。因为Function是RelayExpr的子类,这个地方返回true。

ObjectRef的get方法返回的是Object*类型:

 const Object* get() const { return data_.get(); }

所以as()中的data_.get()是可以返回BaseFuncNode(基类为Object)指针的。这样as方法就实现了一条继承路线上子类的非Node类型到父类的非Node类型的转换。

IRModule::FromExprInContext中这个地方能转成功,是因为我们传入C++的实际上是Function类型,是BaseFunc类型的子类。如果这个地方就真的是一个RelayExpr(BaseFunc的父类),这个地方是转不成功的(返回nullptr)。

举报

相关推荐

0 条评论