0
点赞
收藏
分享

微信扫一扫

(Ipython)Matplotlib 中将二叉树可视化

墨春 2022-03-10 阅读 36

最近学习红白二叉树,我想如果把二叉树可视化在操作的时候如果出错会比较容易发现。

在网上搜了一圈只有比较简单的ascii 的代码。

自己用Ipython写了一个,比较适合学生。

PS:算法没有做优化,加上matplotlib本身就慢,不适合较高的树。

效果见图:

代码中只有一个简单的二叉树框架,主要算法是根据映射到矩阵再输出图像。

import matplotlib.pyplot as plt
import matplotlib.lines as mlines


class Node():
    def __init__(self):
        self.is_red = False
        self.left = None
        self.right = None
        self.value = 0
        
    def get_height(self): #比较慢的方法,扫描一遍整个树获取长度
        layers = [self]
        layer_count = 0
        while layers:
            layer_count += 1
            new_list = []
            for node in layers:
                if node.left:
                    new_list.append(node.left)
                if node.right:
                    new_list.append(node.right)
            layers = new_list
        return layer_count
    
    def visualize(self,axis='off'):
        '''
            主要算法:根据二叉树高度创建一个方形的二维矩阵,将节点映射到二维矩阵中,
            遍历二维矩阵并输出图像
        '''

    
        figure, axes = plt.subplots(figsize=(8, 6), dpi=80)
        height = self.get_height()
        width_ = 2**(height-1)
        width = 2 * width_ + 1
        matrix = [[[]for x in range(width)] for y in range(height)]

        matrix[0][width_] = head #put head in the middle position

        for y in range(len(matrix)):
            for x in range(len(matrix[y])):
                node = matrix[y][x]
                if node:
                    x1, y1 = (1/width)*(x+0.5), 1-(1/height)*y-0.2
                    axes.text(x1, y1, str(node.value),color='white',fontsize=FONT_SIZE,fontweight='bold')
                    if node.left:
                        matrix[y+1][x-1] = node.left
                        x2,y2 = (1/width)*(x-0.5),1-(1/height)*(y+1)-0.2
                        line = mlines.Line2D([x1,x2], [y1,y2],zorder= -1)
                        axes.add_line(line)
                    if node.right:
                        matrix[y+1][x+1] = node.right
                        x2,y2 = (1/width)*(x+1.5),1-(1/height)*(y+1)-0.2
                        line = mlines.Line2D([x1,x2], [y1,y2],zorder= -1)
                        axes.add_line(line)
                        
                    cc = plt.Circle(   ((1/width)*(x+0.5), 1-(1/height)*y-0.2 ), 
                                        1/width/2, 
                                        color=('r' if node.is_red else 'black' )) 
                    axes.set_aspect(1) 
                    axes.add_artist(cc,)


        plt.axis(axis)
        plt.show()

def create_empty_tree(): #手动写入了一个二叉树测试用
    global head
    head = Node()
    head.left = Node()
    head.right = Node()
    head.left.left = Node()
    head.left.is_red = True
    head.left.right = Node()
    
create_empty_tree()


FONT_SIZE = 15    #字体大小需要手动调节,节点大小会根据二叉树高度变化
head.visualize()
举报

相关推荐

0 条评论