Given a binary tree, flatten it into a linked list in place: each node's right pointer acts as the "next" pointer, and every left pointer is set to None.
For example, given the tree
         1
        / \
       2   5
      / \   \
     3   4   6
The flattened tree should look like:
   1
    \
     2
      \
       3
        \
         4
          \
           5
            \
             6
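Reading the flattened list from top to bottom gives 1, 2, 3, 4, 5, 6, which is the pre-order (node, then left subtree, then right subtree) of the input tree. The sketch below checks this on the example. It assumes a LeetCode-style TreeNode with val, left and right attributes; the TreeNode class, build_example and preorder_values are hypothetical stand-ins, not part of the original problem.

```python
from typing import List, Optional


class TreeNode:
    # Hypothetical stand-in for LeetCode's TreeNode (val, left, right).
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def build_example() -> TreeNode:
    # The input tree from the example above.
    return TreeNode(1,
                    TreeNode(2, TreeNode(3), TreeNode(4)),
                    TreeNode(5, None, TreeNode(6)))


def preorder_values(root: Optional[TreeNode]) -> List[int]:
    # Iterative pre-order: visit the node, then its left subtree, then its right subtree.
    out: List[int] = []
    stack = [root] if root else []
    while stack:
        node = stack.pop()
        out.append(node.val)
        # Push right first so the left child is popped (visited) first.
        if node.right:
            stack.append(node.right)
        if node.left:
            stack.append(node.left)
    return out


print(preorder_values(build_example()))  # [1, 2, 3, 4, 5, 6]
```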
Hint:
If you look carefully at the flattened tree, you will notice that each node's right child points to the next node of a pre-order traversal.
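The hint translates almost directly into a simple approach: record the nodes in pre-order, then relink each node to its successor. This is a sketch under stated assumptions, not the solution that follows: it uses O(n) extra space for the node list. TreeNode is again a hypothetical stand-in for LeetCode's class, and flatten_via_list and chain_values are illustrative names.

```python
from typing import List, Optional


class TreeNode:
    # Hypothetical stand-in for LeetCode's TreeNode (val, left, right).
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def flatten_via_list(root: Optional[TreeNode]) -> None:
    # Step 1: record every node in pre-order (node, left, right)
    # before any pointer is modified.
    nodes: List[TreeNode] = []

    def collect(node: Optional[TreeNode]) -> None:
        if node:
            nodes.append(node)
            collect(node.left)
            collect(node.right)

    collect(root)
    # Step 2: point each node's right child at its pre-order successor.
    for prev, nxt in zip(nodes, nodes[1:]):
        prev.left = None
        prev.right = nxt
    # The last pre-order node is always a leaf; clear it for clarity.
    if nodes:
        nodes[-1].left = None
        nodes[-1].right = None


def chain_values(root: Optional[TreeNode]) -> List[int]:
    # Walk the right pointers, checking that every left pointer is None.
    out: List[int] = []
    while root:
        assert root.left is None, "not flattened"
        out.append(root.val)
        root = root.right
    return out


root = TreeNode(1, TreeNode(2, TreeNode(3), TreeNode(4)),
                TreeNode(5, None, TreeNode(6)))
flatten_via_list(root)
print(chain_values(root))  # [1, 2, 3, 4, 5, 6]
```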
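class Solution(object):
    def flatten(self, root):
        """
        :type root: TreeNode
        :rtype: void Do not return anything, modify root in-place instead.
        """
        # dfs(node) flattens the subtree rooted at node in place and
        # returns the tail (last node) of the resulting right-linked list,
        # or None for an empty subtree.
        def dfs(node):
            if not node:
                return None

            left_tail = dfs(node.left)
            right_tail = dfs(node.right)

            # Leaf: the node is a one-element list on its own.
            if not left_tail and not right_tail:
                return node

            # Only a left subtree: move the flattened left list to the right.
            if right_tail is None:
                node.right = node.left
                node.left = None
                return left_tail

            # Only a right subtree: it is already flattened and in place.
            if left_tail is None:
                return right_tail

            # Both subtrees: splice the flattened left list between
            # node and the flattened right list.
            right_head = node.right
            node.right = node.left
            node.left = None
            left_tail.right = right_head
            return right_tail

        dfs(root)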
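The recursive solution recurses once per tree level, so a very deep (for example, fully skewed) tree can exceed CPython's default recursion limit of 1000 frames. A common alternative, shown here as a sketch rather than as the original author's method, is an iterative rewiring loop that uses O(1) extra space: whenever a node has a left subtree, attach the node's right subtree to the rightmost node of that left subtree, move the left subtree to the right, and continue down the right pointers. TreeNode, flatten_iterative and chain_values are hypothetical names for illustration.

```python
from typing import List, Optional


class TreeNode:
    # Hypothetical stand-in for LeetCode's TreeNode (val, left, right).
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def flatten_iterative(root: Optional[TreeNode]) -> None:
    node = root
    while node:
        if node.left:
            # Find the rightmost node of the left subtree; its right pointer is free.
            pred = node.left
            while pred.right:
                pred = pred.right
            # In pre-order the original right subtree comes after the whole
            # left subtree, so hang it off that rightmost node.
            pred.right = node.right
            # Move the left subtree into the right slot.
            node.right = node.left
            node.left = None
        # Advance along the (partially) flattened list.
        node = node.right


def chain_values(root: Optional[TreeNode]) -> List[int]:
    # Walk the right pointers, checking that every left pointer is None.
    out: List[int] = []
    while root:
        assert root.left is None, "not flattened"
        out.append(root.val)
        root = root.right
    return out


root = TreeNode(1, TreeNode(2, TreeNode(3), TreeNode(4)),
                TreeNode(5, None, TreeNode(6)))
flatten_iterative(root)
print(chain_values(root))  # [1, 2, 3, 4, 5, 6]
```

Because the loop never recurses, it also handles a left-skewed chain of several thousand nodes, which would overflow the default recursion limit in the recursive version.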