# Copyright (C) 2023 Thomas Hoppe (h0bB1T). All rights reserved.
#
# Unauthorized copying of this file via any medium is strictly prohibited.
# Proprietary and confidential.

import bpy, gpu

from gpu_extras.batch import batch_for_shader
from typing import List, Tuple, Union

from ..utils.blender import is_400_or_gt
from ..preferences import PreferencesPanel
from ..registries.texture_registry import TextureRegistry

# TODO:
# Preview Lag.

# Vertex shader for the background quad: transforms the 2D input position by
# ModelViewProjectionMatrix (z = 0, w = 1) and then forces gl_Position.z to 1.0.
__bg_vs = """
    uniform mat4 ModelViewProjectionMatrix;
    in vec2 pos;

    void main() {
        gl_Position = ModelViewProjectionMatrix * vec4(pos.xy, 0.0f, 1.0f);
        gl_Position.z = 1.0;
    }
"""

# Fragment shader for the background quad. It mixes color1/color2 into a
# checkerboard with 16 px cells and masks the result with a rounded-box signed
# distance function (radius 8, soft edge); fragments whose resulting alpha is
# below 0.5 are discarded. `corner` and `size` are the rectangle's (x, y)
# origin and (w, h) extent, used to position the mask and the checker pattern.
__bg_fs = """
    uniform vec4 color1;
    uniform vec4 color2;
    uniform vec2 corner;
    uniform vec2 size;
    out vec4 FragColor;

    float roundedBoxSDF(vec2 CenterPosition, vec2 Size, float Radius) {
        return length(max(abs(CenterPosition)-Size+Radius,0.0))-Radius;
    }

    void main() {
        float radius = 8;
        float edgeSoftness = 1.0;
        int checkerSize = 16;
        
        vec2 center = size * 0.5;
        float distance = roundedBoxSDF(gl_FragCoord.xy - center - corner, center, radius);
        float smoothedAlpha = 1.0f-smoothstep(0.0f, edgeSoftness * 2.0f,distance);
        vec4 shape = mix(vec4(0.0f, 1.0f, 0.0f, 0.0f), vec4(1.0f, 1.0f, 1.0f, smoothedAlpha), smoothedAlpha);
        
        vec2 checkerboard_coords = (gl_FragCoord.xy - corner) / float(checkerSize);
        int checkerboard_pattern = (int(checkerboard_coords.x) % 2 + int(checkerboard_coords.y) % 2) % 2;
        vec4 final_color = mix(color1, color2, float(checkerboard_pattern));

        vec4 o = final_color * shape;
        if (o.a < 0.5)
            discard;
        else
            FragColor = o;
    }
"""
# Built once when this module is imported and reused by every draw_bg() call.
__bg_shader = gpu.types.GPUShader(__bg_vs, __bg_fs)
# Padding between images and to bg border.
__padding = 2
# Vertical gap between a node's location and the first row of image tiles.
__dist = 8

# Handle returned by draw_handler_add(); kept so that the handler can be
# removed again on shutdown. None until initialize_node_visualizer() has run.
__node_enhancer_handle = None

def draw_bg(x: int, y: int, w: int, h: int):
    """
    Render the checkered, round-cornered backdrop behind the image tiles.

    (x, y) is the origin of the rectangle and (w, h) its extent. The same
    values are forwarded to the background shader as `corner` and `size`.
    """
    right, top = x + w, y + h
    quad = ((x, y), (x, top), (right, top), (right, y))
    batch = batch_for_shader(__bg_shader, 'TRI_FAN', {"pos": quad})

    __bg_shader.bind()
    # Two dark greys for the checker cells, then the rectangle geometry.
    uniforms = (
        ("color1", (0.1, 0.1, 0.1, 1.0)),
        ("color2", (0.12, 0.12, 0.12, 1.0)),
        ("corner", (x, y)),
        ("size", (w, h)),
    )
    for uniform_name, uniform_value in uniforms:
        __bg_shader.uniform_float(uniform_name, uniform_value)
    batch.draw(__bg_shader)


def draw_rect(x: int, y: int, w: int, h: int, image: Union[bpy.types.Image, int]):
    """
    Draw an image into the given rectangle.

    (x, y) is the origin of the rectangle and (w, h) its extent. If `image`
    is a `bpy.types.Image` it is converted to a GPU texture and the vertical
    texture range is scaled by the image's aspect ratio; any other value is
    handed to the shader's "image" sampler unchanged.

    Nothing is drawn for a rectangle with zero width or height.
    """
    # A zero-area rect shows nothing, and w == 0 would divide by zero below.
    # This happens when the editor is zoomed out so far that a tile collapses
    # to less than one pixel.
    if w == 0 or h == 0:
        return

    vertices = ((x, y), (x, y + h), (x + w, y + h), (x + w, y))
    is_blender_image = isinstance(image, bpy.types.Image)

    # Calculate the vertical size to use from texture (V).
    # Depends on rect size ..
    hs = h / w
    # .. and source image size.
    if is_blender_image:
        image_width, image_height = image.size[0], image.size[1]
        # A zero height would divide by zero; keep the rect-only ratio then.
        if image_height:
            hs *= image_width / image_height

    # The builtin image shader was renamed in Blender 4.0.
    shader_name = 'IMAGE' if is_400_or_gt() else '2D_IMAGE'
    shader = gpu.shader.from_builtin(shader_name)
    batch = batch_for_shader(
        shader,
        'TRI_FAN',
        {
            "pos": vertices,
            "texCoord": ((0.0, 0.0), (0.0, hs), (1.0, hs), (1.0, 0.0)),
        },
    )
    shader.bind()
    texture = gpu.texture.from_image(image) if is_blender_image else image
    shader.uniform_sampler("image", texture)
    batch.draw(shader)


def images_of_tree(t: bpy.types.ShaderNodeTree, rec: int) -> List[Union[bpy.types.Image, int]]:
    """
    Search for image nodes in this tree and maybe subtrees.

    t is the node tree to scan; it may be None (a group node without an
    assigned tree), in which case an empty list is returned.
    rec is the remaining recursion budget: sub groups are entered as long
    as `rec - 1` is not zero.

    Returns either a single-element list holding the registry's preview
    entry (when the preview preference is on and an entry is found), or the
    images of this tree's image texture nodes followed by those found in
    sub groups.
    """
    # Group nodes can exist without a node tree; there is nothing to show.
    if t is None:
        return []

    # Should we try to render the preview image?
    if PreferencesPanel.get().node_visualizer_snw_preview:
        snw_node_info = t.snw_node_info
        if snw_node_info:
            try:
                entry = TextureRegistry.instance().get_image_by_snw_info(snw_node_info, with_gpu=True)
                if entry:
                    return [entry.gpu]
            except Exception:
                # Best effort: this runs on every redraw, so a failing
                # lookup silently falls back to the individual textures.
                pass

    # Find the single textures.
    own, sub = [], []
    descend = rec - 1 != 0
    for n in t.nodes:
        if isinstance(n, bpy.types.ShaderNodeTexImage) and n.image:
            own.append(n.image)
        if descend and isinstance(n, bpy.types.ShaderNodeGroup):
            sub.extend(images_of_tree(n.node_tree, rec - 1))
    return own + sub


def images_of_node(n: bpy.types.Node) -> List[Union[bpy.types.Image, int]]:
    """
    Return the list of images that should be shown on top of this node.
    """
    # An image texture node with an image assigned shows exactly that image.
    if isinstance(n, bpy.types.ShaderNodeTexImage) and n.image:
        return [n.image]

    # Every other node type except groups has nothing to show.
    if not isinstance(n, bpy.types.ShaderNodeGroup):
        return []

    # Groups: collect from the subtree, recursion depth comes from preferences.
    subtree = n.node_tree
    depth = PreferencesPanel.get().node_visualizer_recursion
    return images_of_tree(subtree, depth)


def final_location(n: bpy.types.Node) -> Tuple[float, float]:
    """
    Compute the node's location with the locations of all its parents added.

    Walks up the `parent` chain and sums every ancestor's location onto the
    node's own one, returning the result as an (x, y) tuple.
    """
    total_x, total_y = n.location
    ancestor = n.parent
    while ancestor:
        offset_x, offset_y = ancestor.location[0], ancestor.location[1]
        total_x, total_y = total_x + offset_x, total_y + offset_y
        ancestor = ancestor.parent

    return total_x, total_y


def node_visualizer():
    """
    This is called every time when the node tree space is redrawn.
    Everything takes place here.

    For every node of the edited shader tree that has images to show, the
    images are laid out as a centered grid of square tiles above the node,
    a checkered background is drawn behind the grid and the images are
    drawn on top of it. Does nothing if the feature is disabled in the
    preferences or the edited tree is not a shader tree.
    """
    prefs = PreferencesPanel.get()
    # Check if we should be active at all.
    if not prefs.node_visualizer_enabled:
        return

    context = bpy.context

    # Tree to work on, only shader node trees are handled.
    tree = context.space_data.edit_tree # type: bpy.types.NodeTree
    if not tree or tree.type != 'SHADER':
        return

    # Some placement calculations ..
    tile_size = prefs.node_visualizer_size
    # A non-positive tile size cannot be laid out (and would divide by zero).
    if tile_size <= 0:
        return
    dist = __padding + tile_size # Image2Image dist

    region = context.region
    ui_scale = context.preferences.system.ui_scale

    # Calculate vec2 from view to region space ..
    def to_region(x: float, y: float) -> Tuple[float, float]:
        return region.view2d.view_to_region(x * ui_scale, y * ui_scale, clip=False)

    # Loop over all nodes ..
    for n in tree.nodes:
        # Drop duplicates but keep the order in which the images were found,
        # so the tiles do not get shuffled into arbitrary (hash) order.
        images = list(dict.fromkeys(images_of_node(n)))
        if not images:
            continue

        # Images that fit in a single row, at least 1.
        images_per_row = max(int(n.width // tile_size), 1)

        # .. to center them ..
        total_width = dist * min(len(images), images_per_row) - __padding
        inset = (n.width - total_width) / 2

        # Bounding box of all tiles in region space; `images` is not empty,
        # so every bound is replaced by a real coordinate in the loop below.
        bl, bb = float("inf"), float("inf")
        br, bt = float("-inf"), float("-inf")

        # Build a list of images with its render position.
        placing = []
        location = final_location(n)
        for i, image in enumerate(images):
            # Grid pos.
            row, col = divmod(i, images_per_row)
            # Corners in view space ..
            xv = location[0] + col * dist + inset
            yv = location[1] + __dist + row * dist
            # .. and in region space.
            x, y = to_region(xv, yv)
            r, t = to_region(xv + tile_size, yv + tile_size)
            # Place image here ..
            placing.append(((x, y, r - x, t - y), image))

            # Grow bounding box.
            bl, br, bb, bt = min(bl, x), max(br, r), min(bb, y), max(bt, t)

        # Draw the bg rectangle, enlarged by the padding on every side.
        draw_bg(
            bl - __padding,
            bb - __padding,
            br - bl + 2 * __padding,
            bt - bb + 2 * __padding
        )

        # Place all images.
        for pos, image in placing:
            draw_rect(*pos, image)


def initialize_node_visualizer():
    """
    Startup hook: installs `node_visualizer` as a draw handler of the node
    editor ('WINDOW' region, 'POST_PIXEL' pass) and remembers the returned
    handle in the module-level variable.
    """
    global __node_enhancer_handle
    add_handler = bpy.types.SpaceNodeEditor.draw_handler_add
    # The handler takes no arguments, hence the empty tuple.
    __node_enhancer_handle = add_handler(node_visualizer, (), 'WINDOW', 'POST_PIXEL')


def deinitialize_node_visualizer():
    """
    Called on shutdown, unregister the draw handler in node space.

    Does nothing if no handler is currently registered, so it is safe to
    call without a prior initialization or more than once.
    """
    global __node_enhancer_handle
    # Nothing registered (or already removed): removing would raise.
    if __node_enhancer_handle is None:
        return

    bpy.types.SpaceNodeEditor.draw_handler_remove(
        __node_enhancer_handle,
        'WINDOW'
    )
    # Forget the handle so a stale one is never passed to Blender again.
    __node_enhancer_handle = None
