# Copyright (C) 2022 Thomas Hoppe (h0bB1T). All rights reserved.
#
# Unauthorized copying of this file via any medium is strictly prohibited.
# Proprietary and confidential.

import bpy

from typing import List, Set, Tuple, Union
from .blender import is_400_or_gt


class NodeBuilder:
    """
    Base class to help building node trees.

    Wraps the tree socket API, which differs between Blender < 4.0
    (``tree.inputs`` / ``tree.outputs``) and Blender 4.0+ (``tree.interface``),
    and provides a simple automatic layout for the nodes of ``self.tree``.
    Every ``tree_*`` method works on ``self.tree`` unless a tree is passed.
    """
    def __init__(self, space_x: int = 100, space_y: int = 250):
        # Horizontal gap between node columns / vertical gap between stacked
        # nodes, both applied by _organize().
        self.space_x, self.space_y = space_x, space_y
        # Tree used when no explicit tree is passed; presumably assigned by
        # subclasses / callers before any method is used.
        self.tree = None # type: bpy.types.NodeTree


    def add(self, type: str):
        """
        Add node with given type (node idname) to self.tree and return it.
        """
        return self.tree.nodes.new(type)


    # ------------------------------------------------------------------
    # Blender 4.0+ interface helpers.
    # ------------------------------------------------------------------

    @staticmethod
    def _resolve_socket_type(sub_type: str) -> Tuple[str, str]:
        """
        Map a socket sub type name to the (socket_type, subtype) pair passed to
        ``tree.interface.new_socket`` in Blender 4.0+.

        'FloatFactor' becomes ('NodeSocketFloat', 'FACTOR'). Every other name
        becomes ('NodeSocket<sub_type>', ''), where '' means "no subtype".
        """
        if sub_type == 'FloatFactor':
            return 'NodeSocketFloat', 'FACTOR'
        return 'NodeSocket%s' % sub_type, ''


    def _new_interface_socket(self, tree: bpy.types.NodeTree, name: str, in_out: str, sub_type: str):
        """
        Create an interface socket (in_out is 'INPUT' or 'OUTPUT') in Blender 4.0+,
        applying the subtype from _resolve_socket_type when there is one.
        """
        socket_type, subtype = self._resolve_socket_type(sub_type)
        socket = tree.interface.new_socket(name = name, in_out = in_out, socket_type = socket_type)
        if subtype:
            socket.subtype = subtype
        return socket


    @staticmethod
    def _interface_sockets(tree: bpy.types.NodeTree, in_out: str) -> list:
        """
        All interface sockets of the given direction, in items_tree order.
        """
        return [ io for io in tree.interface.items_tree if io.item_type == 'SOCKET' and io.in_out == in_out ]


    @staticmethod
    def _interface_find(tree: bpy.types.NodeTree, name: str, in_out: str) -> tuple:
        """
        Return (index, socket) of the first interface socket with the given
        direction and name, or (-1, None).

        The index is the position in ``tree.interface.items_tree``. That
        enumeration also counts panels and sockets of the other direction.
        """
        for n, io in enumerate(tree.interface.items_tree):
            if io.item_type == 'SOCKET' and io.in_out == in_out and io.name == name:
                return n, io
        return -1, None


    # ------------------------------------------------------------------
    # Tree inputs.
    # ------------------------------------------------------------------

    def tree_inputs(self, tree: bpy.types.NodeTree = None) -> List[bpy.types.NodeSocket]:
        """
        Return the input sockets of a tree.
        """
        tree = tree if tree else self.tree
        if is_400_or_gt():
            return self._interface_sockets(tree, 'INPUT')
        return tree.inputs


    def tree_add_input(self, name: str, sub_type: str, tree: bpy.types.NodeTree = None) -> bpy.types.NodeSocket:
        """
        Create the respective input and return it.

        sub_type is the socket type without the 'NodeSocket' prefix, e.g. 'Float'
        or 'FloatFactor'.
        """
        tree = tree if tree else self.tree
        if is_400_or_gt():
            return self._new_interface_socket(tree, name, 'INPUT', sub_type)
        return tree.inputs.new('NodeSocket%s' % sub_type, name)


    def tree_remove_input(self, name: str, tree: bpy.types.NodeTree = None):
        """
        Remove specific input if exists.
        """
        tree = tree if tree else self.tree
        socket = self.tree_get_input(name, tree)
        if socket is not None:
            if is_400_or_gt():
                tree.interface.remove(socket)
            else:
                tree.inputs.remove(socket)


    def tree_has_input(self, name: str, tree: bpy.types.NodeTree = None) -> bool:
        """
        Check if tree has input with given name.
        """
        tree = tree if tree else self.tree
        if is_400_or_gt():
            return self._interface_find(tree, name, 'INPUT')[1] is not None
        return name in tree.inputs


    def tree_get_input(self, name: str, tree: bpy.types.NodeTree = None) -> Union[bpy.types.NodeSocket, None]:
        """
        Return input socket by name or None.
        """
        tree = tree if tree else self.tree
        if is_400_or_gt():
            return self._interface_find(tree, name, 'INPUT')[1]
        return tree.inputs.get(name)


    def tree_input_index(self, name: str, tree: bpy.types.NodeTree = None) -> int:
        """
        Return index of input in node tree, or -1 if not found.

        Blender 4.0+: position in tree.interface.items_tree (counts all items).
        Older: position in tree.inputs.
        """
        tree = tree if tree else self.tree
        if is_400_or_gt():
            return self._interface_find(tree, name, 'INPUT')[0]
        return tree.inputs.find(name)


    def tree_move_input(self, socket: bpy.types.NodeSocket, index: int, tree: bpy.types.NodeTree = None):
        """
        Move tree input socket to specific index.
        """
        tree = tree if tree else self.tree
        if is_400_or_gt():
            tree.interface.move(socket, index)
        else:
            for i, s in enumerate(tree.inputs):
                if s == socket:
                    tree.inputs.move(i, index)
                    break


    # ------------------------------------------------------------------
    # Tree outputs.
    # ------------------------------------------------------------------

    def tree_outputs(self, tree: bpy.types.NodeTree = None) -> List[bpy.types.NodeSocket]:
        """
        Return the output sockets of a tree.
        """
        tree = tree if tree else self.tree
        if is_400_or_gt():
            return self._interface_sockets(tree, 'OUTPUT')
        return tree.outputs


    def tree_add_output(self, name: str, sub_type: str, tree: bpy.types.NodeTree = None) -> bpy.types.NodeSocket:
        """
        Create the respective output and return it.

        sub_type is the socket type without the 'NodeSocket' prefix. In Blender
        4.0+, 'FloatFactor' is handled the same way as in tree_add_input.
        """
        tree = tree if tree else self.tree
        if is_400_or_gt():
            return self._new_interface_socket(tree, name, 'OUTPUT', sub_type)
        return tree.outputs.new('NodeSocket%s' % sub_type, name)


    def tree_remove_output(self, name: str, tree: bpy.types.NodeTree = None):
        """
        Remove specific output if exists.
        """
        tree = tree if tree else self.tree
        socket = self.tree_get_output(name, tree)
        if socket is not None:
            if is_400_or_gt():
                tree.interface.remove(socket)
            else:
                tree.outputs.remove(socket)


    def tree_has_output(self, name: str, tree: bpy.types.NodeTree = None) -> bool:
        """
        Check if tree has output with given name.
        """
        tree = tree if tree else self.tree
        if is_400_or_gt():
            return self._interface_find(tree, name, 'OUTPUT')[1] is not None
        return name in tree.outputs


    def tree_get_output(self, name: str, tree: bpy.types.NodeTree = None) -> Union[bpy.types.NodeSocket, None]:
        """
        Return output socket by name or None.
        """
        tree = tree if tree else self.tree
        if is_400_or_gt():
            return self._interface_find(tree, name, 'OUTPUT')[1]
        return tree.outputs.get(name)


    def tree_output_index(self, name: str, tree: bpy.types.NodeTree = None) -> int:
        """
        Return index of output in node tree, or -1 if not found.

        Blender 4.0+: position in tree.interface.items_tree (counts all items).
        Older: position in tree.outputs.
        """
        tree = tree if tree else self.tree
        if is_400_or_gt():
            return self._interface_find(tree, name, 'OUTPUT')[0]
        return tree.outputs.find(name)


    def tree_move_output(self, socket: bpy.types.NodeSocket, index: int, tree: bpy.types.NodeTree = None):
        """
        Move tree output socket to specific index.
        """
        tree = tree if tree else self.tree
        if is_400_or_gt():
            tree.interface.move(socket, index)
        else:
            for i, s in enumerate(tree.outputs):
                if s == socket:
                    tree.outputs.move(i, index)
                    break


    # ------------------------------------------------------------------
    # Layout.
    # ------------------------------------------------------------------

    def _organize(self):
        """
        Turn the provisional node locations into the final spacing.

        Nodes sharing the exact same location.x form a column. Columns are
        laid out left to right in ascending x. Within a column, nodes are
        stacked downwards in tree.nodes order, separated by space_y. Columns
        are separated by space_x. Each column is then shifted so its stacked
        extent is centered on y = 0, and the whole layout is shifted so it is
        centered on x = 0.
        """
        columns = {} # location.x -> nodes at that x, in tree.nodes order.
        for n in self.tree.nodes:
            columns.setdefault(n.location.x, []).append(n)

        cx = 0
        for k in sorted(columns):
            column = columns[k] # type: List[bpy.types.Node]

            # dimensions may still be 0 for nodes that were never drawn, so
            # fall back to the width/height properties.
            width = 0
            for n in column:
                width = max(width, n.dimensions[0], n.width)

            cy = 0
            for n in column:
                n.location = (cx, cy)
                cy -= (max(n.dimensions[1], n.height) + self.space_y)
            # Total stacked height (without the trailing gap), halved.
            half_height = -(cy + self.space_y) / 2

            for n in column:
                n.location[1] = n.location[1] + half_height

            cx += width + self.space_x

        half_width = (cx - self.space_x) / 2
        for n in self.tree.nodes:
            n.location[0] -= half_width


    def _arrange_left_of(self, current: bpy.types.Node):
        """
        Called recursively from arrange (old algorithm, currently unused).

        Places every node that feeds an input of "current" at least 100 units
        left of it, stacks them at 100-unit steps from current's y, and
        recurses into them. Only the first link of each input is followed.
        """
        cy = current.location.y

        # Iterate the node's own input sockets. The tree interface API
        # (tree_inputs) does not apply to nodes in Blender 4.0+.
        for s in current.inputs:
            if len(s.links) > 0:
                other = s.links[0].from_socket.node
                other.location.x = min(other.location.x, current.location.x - 100)
                other.location.y = cy
                cy += 100
                self._arrange_left_of(other)


    def _arrange_right_of(self):
        """
        New arrange algorithm, the old one took too long with many connections.

        Assign every node of self.tree to a column, so that each node ends up
        left of all nodes its outputs link to. Then write provisional
        locations (x = -column * 100, y = row * 100) that _organize() turns
        into the final layout.
        """
        # Pair every node with the set of nodes connected to its output sockets.
        relations = [] # type: List[Tuple[bpy.types.Node, Set[bpy.types.Node]]]
        for n in self.tree.nodes:
            right = set()
            for s in n.outputs:
                for l in s.links:
                    right.add(l.to_socket.node)
            relations.append((n, right))

        # Column table; index == columns counted from the right. All nodes start in column 0.
        tbl = [relations, ]

        # For every column, move each node that links into another node of the
        # same column one column further left (into a newly appended column).
        i = 0
        while i < len(tbl):
            column = tbl[i]
            in_column = {n for n, _ in column}

            keep, shift = [], []
            for rel in column:
                if rel[1].isdisjoint(in_column):
                    keep.append(rel)
                else:
                    shift.append(rel)

            if shift:
                if not keep:
                    # Every node of this column links into another node of the
                    # same column. An acyclic graph always has a node without
                    # such a link, so this only happens with cyclic links. The
                    # previous version looped forever here; leave the column as
                    # it is and stop.
                    break
                tbl[i] = keep
                tbl.append(shift)
            i += 1

        # Apply table position of every node into node location.
        for x, column in enumerate(tbl):
            for y, (n, _) in enumerate(column):
                n.location.x = -x * 100 # Reverse left<->right.
                n.location.y = y * 100


    def arrange(self, most_right: bpy.types.Node = None):
        """
        Lay out the nodes of self.tree from left to right, based on the
        linking information: each node is placed left of the nodes its
        outputs link to.

        All nodes are first moved to (space_x, 0). If most_right is not given,
        the first 'NodeGroupOutput' or 'ShaderNodeOutputMaterial' node is used.
        The layout (_arrange_right_of + _organize) only runs when such a node
        exists. Otherwise the nodes stay at (space_x, 0).
        """
        for n in self.tree.nodes:
            n.location = (self.space_x, 0)

        if not most_right:
            for i in self.tree.nodes:
                if i.bl_idname in ('NodeGroupOutput', 'ShaderNodeOutputMaterial'):
                    most_right = i
                    break

        if most_right:
            most_right.location = (0, 0)
            # _arrange_left_of(most_right) is the old recursive algorithm.
            # _arrange_right_of replaces it because it scales better with
            # many links.
            self._arrange_right_of()
            self._organize()
