TVM的Object类是很多类的基类,详细的分析资料可以参考
深入理解TVM:Object家族 - 知乎
深入理解TVM:Object家族(二) - 知乎
TVM源码品读:万物基石——Object类(1) - 知乎
TVM源码品读:万物基石——Object(2) - 知乎
在阅读TVM C++代码的时候,有很多Object的派生类的类型转换需要追溯到Object/ObjectPtr/ObjectRef,所以这里着重分析三者之间的关系。我们可以只保留三者的包含关系代码:
class TVM_DLL Object {
public:
...
protected:
...
private:
...
friend class ObjectPtr;
...
};
template <typename T>
class ObjectPtr {
public:
...
private:
Object* data_{nullptr};
...
friend class Object;
friend class ObjectRef;
...
};
class ObjectRef {
public:
...
protected:
ObjectPtr<Object> data_;
};
从上面的代码可以看到,ObjectPtr的数据成员data_是一个Object指针,ObjectRef的数据成员data_是一个 ObjectPtr实例。
这里ObjectPtr可以通过自己的私有成员data_操作对应的Object实例,那么ObjectPtr想操作对应的ObjectRef怎么办呢?ObjectPtr的定义中申明了一个友元函数GetRef:
template <typename T>
class ObjectPtr {
public:
...
private:
Object* data_{nullptr};
...
friend class Object;
friend class ObjectRef;
...
template <typename RelayRefType, typename ObjType>
friend RelayRefType GetRef(const ObjType* ptr);
...
};
template <typename RefType, typename ObjType>
inline RefType GetRef(const ObjType* ptr) {
static_assert(std::is_base_of<typename RefType::ContainerType, ObjType>::value,
"Can only cast to the ref of same container type");
if (!RefType::_type_is_nullable) {
ICHECK(ptr != nullptr);
}
return RefType(ObjectPtr<Object>(const_cast<Object*>(static_cast<const Object*>(ptr))));
}
这里我们忽略GetRef函数一开始的检查,只看最后return的那一句。
const_cast<Object*>(static_cast<const Object*>(ptr)) 是将一个ObjType类型的指针强转为Object类型并强制去掉const属性。在TVM的Object家族中,所有以Node为结尾的类名都是继承自Object, 不以Node结尾的类名都是继承自ObjectRef。所以某个以Node结尾的类实例,是可以强转为Object*类型的。
ObjectPtr<object>(xxx)这个就调用了ObjectPtr的构造函数生成一个ObjectPtr实例:
explicit ObjectPtr(Object* data) : data_(data) {
if (data != nullptr) {
data_->IncRef();
}
}
IncRef是增加Object的引用次数,这里不细究
接下来RefType(xxx)是调用RefType的构造函数,参数为 ObjectPtr类型,生成一个RefType类实例。如果这个RefType就是ObjectRef,看下对应的构造函数:
explicit ObjectRef(ObjectPtr<Object> data) : data_(data) {}
这样就生成了一个ObjectRef类型。如果RefType是ObjectRef的子类,并且是ObjType对应的类型(比如IRModule和IRModuleNode),那么就可以由某个类型的指针类型,得到对应的Ref类型了。
这种转换在代码中很多,比如说IRModule::FromExprInContext中:
if (auto* func_node = expr.as<BaseFuncNode>()) {
func = GetRef<BaseFunc>(func_node);
这里expr强转为BaseFuncNode类型,以Node结尾的类型都是继承自Object,所以这里使用GetRef就可以转换为BaseFunc类型(不以Node结尾的都继承自ObjectRef)。
这个代码里面as是ObjectRef的成员方法,用作继承路径上下游之间的转换,as的定义:
template <typename ObjectType>
inline const ObjectType* ObjectRef::as() const {
if (data_ != nullptr && data_->IsInstance<ObjectType>()) {
return static_cast<ObjectType*>(data_.get());
} else {
return nullptr;
}
}
替换下模板类型
inline const BaseFuncNode* ObjectRef::as() const {
if (data_ != nullptr && data_->IsInstance<BaseFuncNode>()) {
return static_cast<BaseFuncNode*>(data_.get());
} else {
return nullptr;
}
}
在IRModule::FromExprInContext中expr参数的类型为RelayExpr类型,而我们从python传到这里的是一个Function类型,class Function继承自BaseFunc, BaseFunc继承自RelayExpr,基类是ObjectRef。
IsInstance方法是通过类型的type index判断一个实例属不属于某种类型或者子类。因为Function是RelayExpr的子类,这个地方返回true。
ObjectRef的get方法返回的是Object*类型:
const Object* get() const { return data_.get(); }
所以as()中的data_.get()是可以返回BaseFuncNode(基类为Object)指针的。这样as方法就实现了一条继承路线上子类的非Node类型到父类的非Node类型的转换。
IRModule::FromExprInContext中这个地方能转成功,是因为我们传入C++的实际上是Function类型,是BaseFunc类型的子类。如果这个地方就真的是一个RelayExpr(BaseFunc的父类),这个地方是转不成功的(返回nullptr)。