/*
// Definition for a Node.
class Node {
public:
    int val = NULL;
    Node* prev = NULL;
    Node* next = NULL;
    Node* child = NULL;

    Node() {}

    Node(int _val, Node* _prev, Node* _next, Node* _child) {
        val = _val;
        prev = _prev;
        next = _next;
        child = _child;
    }
};
*/
class Solution {
public:
    // Flattens a multilevel doubly linked list in place: each child list is
    // spliced in directly after the node that owns it, and every child
    // pointer is set to NULL. Returns head unchanged (NULL stays NULL).
    Node* flatten(Node* head) {
        if (head != NULL) {
            flatten2(head);
        }
        return head;
    }

    // Flattens the list that starts at head (same splicing as flatten) and
    // returns the last node of the resulting single-level list, or NULL if
    // head is NULL.
    //
    // When the walk reaches a node with a child, that child's own level is
    // scanned to its end and the whole level is linked in between the node
    // and its former successor. Deeper children inside the spliced level are
    // handled later, when the walk arrives at them, so no recursion is used.
    Node* flatten2(Node* head) {
        Node* last = head;
        for (Node* cur = head; cur != NULL; cur = cur->next) {
            Node* sub = cur->child;
            if (sub != NULL) {
                // Find the end of the child's level (deeper levels are
                // not yet spliced, so this only walks that one level).
                Node* subEnd = sub;
                while (subEnd->next != NULL) {
                    subEnd = subEnd->next;
                }

                // Reattach the old successor after the child level.
                Node* after = cur->next;
                subEnd->next = after;
                if (after != NULL) {
                    after->prev = subEnd;
                }

                // Link the child level directly after cur.
                cur->next = sub;
                sub->prev = cur;
                cur->child = NULL;
            }
            last = cur;
        }
        return last;
    }
};