# Copyright (C) 2022 Thomas Hoppe (h0bB1T). All rights reserved.
#
# Unauthorized copying of this file via any medium is strictly prohibited.
# Proprietary and confidential.

import bpy

from typing import Tuple, Any, List

from bpy.types import NodeSocketFloat

from ...utils.node_builder_cycles import NodeBuilderCycles


def _update_default_values(
    nodes: List[bpy.types.ShaderNode], 
    tree: bpy.types.ShaderNodeTree, 
    socket: str, 
    value: Any
    ):
    """
    Inserting sockets, e.g. AO, Displacement, ... does set the internal
    default value for the input, but not the value at the instanced node groups.
    So with have to traverse over all nodes in all node trees of every material
    recursively and if there's a node group with our tree, update the input.
    """
    for n in nodes:
        if n.bl_idname == 'ShaderNodeGroup':
            if n.node_tree == tree:
                n.inputs[socket].default_value = value
            else:
                _update_default_values(n.node_tree.nodes, tree, socket, value)


def _update_all_default_values(
    tree: bpy.types.ShaderNodeTree, 
    socket: str, 
    value: Any
    ):
    """
    Push a new default value for input `socket` to every instance of the
    node group `tree` found in any material's node tree.

    The traversal (including instances nested inside other node groups) is
    delegated to `_update_default_values`. Materials without a node tree are
    skipped.
    """
    material_trees = (
        getattr(material, 'node_tree', None) for material in bpy.data.materials
    )
    for material_tree in material_trees:
        if not material_tree:
            continue
        _update_default_values(material_tree.nodes, tree, socket, value)


def add_power_weight(nb: NodeBuilderCycles, value: NodeSocketFloat, weight: NodeSocketFloat) -> NodeSocketFloat:
    """
    Raise `value` to an exponent that is controlled linearly by `weight`.

    The node graph computes the exponent as

        (1 - [weight < 0]) * (|weight| + 1) + [weight < 0] * 1 / (|weight| + 1)

    which is 1 + weight for weight >= 0 and 1 / (1 + |weight|) for weight < 0.
    With n = 1 + |weight|, the exponent therefore moves between 1/n and n.
    NOTE(review): this reading assumes `nb.add_math_inv` computes 1 - x;
    confirm in NodeBuilderCycles.

    Returns the result of the final `nb.add_math_power` call.
    """
    # Nodes are created in the same order as before, so node names and
    # layout are unchanged.
    is_negative = nb.add_math_less_than(val=weight, threshold=0.0)
    grown = nb.add_math_add(val0=nb.add_math_abs(weight), val1=1.0)

    # Active when weight >= 0: exponent = |weight| + 1.
    positive_branch = nb.add_math_mul(
        val0=nb.add_math_inv(is_negative),
        val1=grown
    )
    # Active when weight < 0: exponent = 1 / (|weight| + 1).
    negative_branch = nb.add_math_mul(
        val0=is_negative,
        val1=nb.add_math_div(val0=1.0, val1=grown)
    )

    exponent = nb.add_math_add(val0=positive_branch, val1=negative_branch)
    return nb.add_math_power(val=value, exp=exponent)


def __enable_blend_normal(nb: NodeBuilderCycles):
    """
    Ensure the group has the 'Blend Normal Weight' and 'Blend Normal' inputs.

    If 'Blend Normal Weight' already exists, nothing happens. Otherwise, a
    'FloatFactor' input 'Blend Normal Weight' (default 0.0, range 0.0-1.0)
    and a 'Vector' input 'Blend Normal' (default (0, 0, 1)) are added. They
    are positioned relative to 'Bevel Radius' / 'Bump From Max', and
    'Blend Normal' is hidden.
    NOTE(review): unlike _enable_bevel, this does not push the new defaults
    to already instanced group nodes; confirm that is intended.
    """
    if nb.tree_has_input('Blend Normal Weight'):
        return

    new_inputs = [
        ( 'FloatFactor', 'Blend Normal Weight', 0.0, 0.0, 1.0 ),
        ( 'Vector', 'Blend Normal', (0.0, 0.0, 1.0) ),
    ]
    nb.add_inputs(
        new_inputs,
        after=[ 'Bevel Radius', 'Bump From Max' ],
        hidden={ 'Blend Normal' }
    )


def _enable_bevel(nb: NodeBuilderCycles, enable: bool):
    """
    Add or remove the 'Bevel Radius' group input.

    When `enable` is false, the input is removed via `nb.remove_inputs`.
    When `enable` is true and the input does not exist yet, it is added as a
    'FloatFactor' input (default 0.05, range 0.0-1.0) after 'Bump From Max'.
    That default is then propagated to all existing instances of the group
    in materials.
    """
    input_name = 'Bevel Radius'
    default_radius = 0.05

    if not enable:
        nb.remove_inputs([ input_name ])
        return

    if nb.tree_has_input(input_name):
        return

    nb.add_inputs(
        [ ( 'FloatFactor', input_name, default_radius, 0.0, 1.0 ) ],
        after=[ 'Bump From Max' ]
    )
    # Newly inserted sockets do not update values on existing group instances.
    _update_all_default_values(nb.tree, input_name, default_radius)


def setup_simple_image_texture(
    nb: NodeBuilderCycles,
    image: bpy.types.NodeSocket,
    alpha: bpy.types.NodeSocket
    ):
    """
    Build the processing chain for a plain color image.

    The `image` color passes through an HSV node driven by the 'Hue',
    'Saturation' and 'Value' group inputs, and the result goes to
    'Color (sRGB)'. The adjusted color also passes through a gamma node
    (gamma 1/2.2), and that result goes to 'Color (Non-Color)'. `alpha` is
    forwarded unchanged to 'Alpha'. Finally, the nodes are arranged.
    """
    hue = nb.input('Hue')
    saturation = nb.input('Saturation')
    brightness = nb.input('Value')
    adjusted = nb.add_hsv(
        color=image,
        hue=hue,
        saturation=saturation,
        value=brightness
    ).outputs[0]

    # Presumably re-encodes the (linear) color with an approximate sRGB
    # gamma for use as non-color data; confirm intent.
    non_color = nb.add_gamma(color=adjusted, gamma=1/2.2).outputs[0]

    wiring = (
        (adjusted, 'Color (sRGB)'),
        (non_color, 'Color (Non-Color)'),
        (alpha, 'Alpha'),
    )
    for source, output_name in wiring:
        nb.wire_sockets(source, nb.output(output_name))

    nb.arrange()


def setup_gray_texture(
    nb: NodeBuilderCycles,
    image: bpy.types.NodeSocket,
    size: Tuple[int, int],
    bevel_samples: int = 4
    ):
    """
    Setup all processing on a single gray channel. Used for grunge and effect image nodes.

    Two branches are built from the gray `image` value:

    * 'Value' output: `image` is remapped (clamped) with the 'From Min',
      'From Max', 'To Min' and 'To Max' group inputs. It is then shaped by
      add_power_weight() using the negated 'Weight' input.
    * 'Normal' output: `image` is remapped (clamped) with 'Bump From Min' and
      'Bump From Max' and shaped by 'Bump Weight'. The result drives a bump
      node, which is blended with the 'Blend Normal' input and optionally
      passed through a bevel node.

    :param nb: Builder whose node tree is being populated.
    :param image: Gray value socket to process.
    :param size: Image size in pixels. The bump distance is 1 / max(size),
                 scaled by 'Bump Detail'. A size whose largest side is 0
                 (Blender may report (0, 0) when an image's pixel data is
                 unavailable) falls back to a factor of 1.0 instead of
                 raising ZeroDivisionError.
    :param bevel_samples: Sample count for the bevel node. A value <= 0
                          skips the bevel node and removes the
                          'Bevel Radius' group input.
    """
    # --- 'Value' output ---------------------------------------------------
    value_mapped = nb.add_map_range(
        value=image,
        clamp=True,
        fmin=nb.input('From Min'),
        fmax=nb.input('From Max'),
        tmin=nb.input('To Min'),
        tmax=nb.input('To Max'),
    ).outputs[0]

    output = add_power_weight(
        nb,
        value_mapped,
        # Weight is negated before shaping. NOTE(review): per add_power_weight,
        # this makes positive weights give exponents < 1; confirm intent.
        nb.add_math_mul(
            val0=nb.input('Weight'),
            val1=-1.0
        )
    )
    nb.wire_sockets(output, nb.output('Value'))

    # --- 'Normal' output --------------------------------------------------
    bump_mapped = nb.add_map_range(
        value=image,
        clamp=True,
        fmin=nb.input('Bump From Min'),
        fmax=nb.input('Bump From Max'),
    ).outputs[0]

    # Bump distance is relative to one texel of the largest image side.
    # Guard against a zero size so a missing image does not abort building.
    largest_side = max(size)
    distance = 1.0 / largest_side if largest_side > 0 else 1.0
    normal = nb.add_bump_map(
        height=add_power_weight(nb, bump_mapped, nb.input('Bump Weight')),
        strength=nb.input('Bump Strength'),
        distance=nb.add_math_mul(
            val0=distance,
            val1=nb.input('Bump Detail')
        )
    ).outputs[0]

    __enable_blend_normal(nb)
    normal = nb.add_normal_mixer(
        nb.input('Blend Normal'),
        normal,
        nb.input('Blend Normal Weight')
    )

    _enable_bevel(nb, bevel_samples > 0)
    if bevel_samples > 0:
        normal = nb.add_bevel(
            samples=bevel_samples,
            radius=nb.input('Bevel Radius'),
            normal=normal
        ).outputs[0]

    nb.wire_sockets(normal, nb.output('Normal'))

    nb.arrange()


def create_incoming_tangent() -> str:
    """
    Return the name of the '.snw_incoming_tangent' node group, building the
    group first if it is not yet in bpy.data.node_groups.

    The group has one 'Vector' output. It contains the dot products of the
    Geometry node's 'Incoming' (view) vector with the tangent, bitangent and
    normal, i.e. the view vector moved to tangent space. The vector is then
    divided by |z| (the absolute normal component) so that x/y can be used
    as a 2D direction.
    """
    name = '.snw_incoming_tangent'
    if name in bpy.data.node_groups:
        return name

    nb = NodeBuilderCycles()
    nb.build_node_tree(name)
    nb.add_outputs([
        ( 'Vector', 'Vector' ),
    ])

    geometry = nb.add_geometry()
    tangent_vec = nb.add_tangent().outputs['Tangent']
    # Presumably a normal map of color (0.5, 1.0, 0.5), i.e. tangent-space
    # direction (0, 1, 0), yields the world-space bitangent; confirm.
    bitangent_vec = nb.add_normal_map(strength=1, color=(0.5, 1.0, 0.5, 1.0)).outputs['Normal']
    normal_vec = geometry.outputs['Normal']
    view_vec = geometry.outputs['Incoming']

    # Dot nodes are created in tangent, bitangent, normal order.
    x, y, z = (
        nb.add_vector_math_dot(axis, view_vec)
        for axis in (tangent_vec, bitangent_vec, normal_vec)
    )

    combined = nb.add_combine_xyz(x, y, z).outputs['Vector']
    result = nb.add_vector_math_div(combined, nb.add_math_abs(z))

    nb.wire_sockets(result, nb.output('Vector'))
    nb.arrange()

    return name
