# Copyright (C) 2022 Thomas Hoppe (h0bB1T). All rights reserved.
#
# Unauthorized copying of this file via any medium is strictly prohibited.
# Proprietary and confidential.

import bpy

from typing import Any, List, Tuple

from .parallax import create_parallax_setup
from ...utils.dev import dbg
from ...utils.node_builder_cycles import NodeBuilderCycles
from ...registries.texture_registry import RepositoryType, TextureRegistry
from ...registries.textures.builders.pbr_builder import build_pbr_node_setup, find_displace_min_max
from ...snw.node_info import encode_pbr_info


def create_pbr(name: str) -> NodeBuilderCycles:
    """
    Create a new PBR node group called `name` with its default socket layout.

    The inputs cover UV placement, color correction, AO, roughness, normal,
    emission and displacement controls. 'UV' and 'Blend Normal' are created
    hidden. Sockets that only some setups need ('Blend', 'Bevel Radius', the
    'Parallax ...' sockets) are not created here; the `_enable_*` helpers in
    this module add or remove them on demand.

    Returns the NodeBuilderCycles instance that owns the new tree.
    """
    builder = NodeBuilderCycles()
    builder.build_node_tree(name)

    # Each spec: (socket type, name[, default[, min, max]])
    input_specs = [
        ('Vector', 'UV'),
        ('Vector', 'UV Offset'),
        ('Float', 'Scale', 1.0),
        # Color Inputs
        ('FloatFactor', 'Hue', 0.5, 0.0, 1.0),
        ('FloatFactor', 'Saturation', 1.0, 0.0, 2.0),
        ('FloatFactor', 'Value', 1.0, 0.0, 2.0),
        ('FloatFactor', 'Brightness', 0.0, -0.5, 0.5),
        ('FloatFactor', 'Contrast', 0.0, -0.5, 0.5),
        ('FloatFactor', 'AO Intensity', 0.25, 0.0, 1.0),
        # Roughness Controls
        ('FloatFactor', 'Roughness Offset', 0.0, -1.0, 1.0),
        # Normal Controls
        ('FloatFactor', 'Normal Strength', 1.0, 0.01, 3.0),
        ('FloatFactor', 'Blend Normal Weight', 0.0, 0.0, 1.0),
        ('Vector', 'Blend Normal', (0.0, 0.0, 1.0)),
        # Others
        ('FloatFactor', 'Emission Strength', 0.0, 0.0, 1000.0),
        # Displacement Controls
        ('FloatFactor', 'Displacement Intensity', 1.0, 0.0, 1.0),
        ('FloatFactor', 'Displacement Midlevel', 0.5, 0.0, 1.0),
    ]
    builder.add_inputs(input_specs, hidden={'UV', 'Blend Normal'})

    output_specs = [
        ('Shader', 'Shader'),
        ('Color', 'Color'),
        ('Color', 'Processed Color'),
        ('Float', 'Alpha'),
        ('Float', 'Metallic'),
        ('Float', 'Specular'),
        ('Float', 'Roughness'),
        ('Vector', 'Normal'),
        ('Color', 'Emission'),
        ('Color', 'Displacement Height'),
        ('Vector', 'Displacement'),
    ]
    builder.add_outputs(output_specs)

    return builder

def _update_default_values(
    nodes: List['bpy.types.ShaderNode'],
    tree: 'bpy.types.ShaderNodeTree',
    socket: str,
    value: Any
    ):
    """
    Set input `socket` to `value` on every group node that instances `tree`.

    Inserting sockets, e.g. AO, Displacement, ... does set the internal
    default value for the input, but not the value at the instanced node groups.
    So we have to traverse the given nodes and, recursively, the trees of any
    nested group nodes, and update the input wherever a group node uses `tree`.

    Group nodes that have no node tree assigned are skipped instead of
    raising AttributeError on `None.nodes`.

    :param nodes: Nodes to scan, e.g. ``material.node_tree.nodes``.
    :param tree: The node group tree whose instances should be updated.
    :param socket: Name of the input socket to set on those instances.
    :param value: Value assigned to that input's ``default_value``.
    """
    for n in nodes:
        if n.bl_idname != 'ShaderNodeGroup':
            continue
        group_tree = n.node_tree
        if group_tree is None:
            # Empty group node: nothing to update and nothing to descend into.
            continue
        # `==` rather than `is`: bpy wrappers for the same datablock are not
        # guaranteed to be the same Python object.
        if group_tree == tree:
            n.inputs[socket].default_value = value
        else:
            _update_default_values(group_tree.nodes, tree, socket, value)


def _update_all_default_values(
    tree: bpy.types.ShaderNodeTree,
    socket: str,
    value: Any
    ):
    """
    Propagate a new default for input `socket` to all instances of `tree`.

    Only material node trees (``bpy.data.materials``) are scanned; each is
    handed to `_update_default_values`, which descends into nested groups.
    Materials without a node tree are ignored.
    """
    material_trees = (
        material.node_tree
        for material in bpy.data.materials
        if getattr(material, 'node_tree', None)
    )
    for material_tree in material_trees:
        _update_default_values(material_tree.nodes, tree, socket, value)


def _enable_uv_offset(nb: NodeBuilderCycles):
    """
    Ensure the 'UV Offset' vector input exists (older trees may lack it).

    When missing, it is added with `after=['UV']`. Otherwise nothing happens.
    """
    socket_name = 'UV Offset'
    if nb.tree_has_input(socket_name):
        return
    nb.add_inputs([('Vector', socket_name)], after=['UV'])


def _enable_blend(nb: NodeBuilderCycles, enable: bool):
    """
    Add or remove the 'Blend' factor input (default 0.5, range 0-1).

    When enabled and missing, it is added with `after=['UV Offset', 'UV']`.
    NOTE(review): presumably the first existing anchor in that list is used;
    confirm in NodeBuilderCycles. When disabled, the input is removed.
    """
    socket_name = 'Blend'
    if not enable:
        nb.remove_inputs([socket_name])
        return
    if nb.tree_has_input(socket_name):
        return
    nb.add_inputs(
        [('FloatFactor', socket_name, 0.5, 0.0, 1.0)],
        after=['UV Offset', 'UV']
    )


def _enable_metallic(nb: NodeBuilderCycles, enable: bool):
    """
    Add or remove the 'Metallic' float output.

    When enabled and missing, it is added with `after=['Alpha']`. When
    disabled, it is removed.
    """
    socket_name = 'Metallic'
    if not enable:
        nb.remove_outputs([socket_name])
    elif not nb.tree_has_output(socket_name):
        nb.add_outputs([('Float', socket_name)], after=['Alpha'])


def _enable_specular(nb: NodeBuilderCycles, enable: bool):
    """
    Add or remove the 'Specular' float output.

    When enabled and missing, it is added with `after=['Metallic', 'Alpha']`,
    so a tree without a 'Metallic' output can still place it. When disabled,
    it is removed.
    """
    socket_name = 'Specular'
    if not enable:
        nb.remove_outputs([socket_name])
    elif not nb.tree_has_output(socket_name):
        nb.add_outputs([('Float', socket_name)], after=['Metallic', 'Alpha'])


def _enable_ao(nb: NodeBuilderCycles, enable: bool):
    """
    Add or remove the 'AO Intensity' factor input (default 0.25, range 0-1).

    When it is newly added (after 'Contrast'), the default is also pushed to
    existing instances of the group via `_update_all_default_values`.
    When disabled, the input is removed.
    """
    socket_name = 'AO Intensity'
    default = 0.25
    if not enable:
        nb.remove_inputs([socket_name])
    elif not nb.tree_has_input(socket_name):
        nb.add_inputs(
            [('FloatFactor', socket_name, default, 0.0, 1.0)],
            after=['Contrast']
        )
        # Existing group instances do not pick up the new default on their own.
        _update_all_default_values(nb.tree, socket_name, default)


def _enable_emission(nb: NodeBuilderCycles, enable: bool):
    """
    Add or remove the emission sockets of the PBR group.

    Enabled: ensures an 'Emission Strength' input (default 0.0, range
    0-1000, placed after 'Blend Normal') and an 'Emission' output (placed
    after 'Normal'). A newly added strength input also has its default pushed
    to existing group instances. Disabled: removes both sockets.

    The 'Emission' output is a Color socket, matching the layout built by
    `create_pbr`; a Float socket would reduce the emission color wired into it
    to a scalar.
    NOTE(review): trees that already carry a Float 'Emission' output are not
    converted here.
    """
    name = 'Emission Strength'
    def_value = 0.0
    if enable:
        if not nb.tree_has_input(name):
            nb.add_inputs(
                [
                    ( 'FloatFactor', name, def_value, 0.0, 1000.0 ),
                ],
                after=[ 'Blend Normal' ]
            )
            _update_all_default_values(nb.tree, name, def_value)
        if not nb.tree_has_output('Emission'):
            nb.add_outputs(
                [
                    # Must be 'Color' to match create_pbr's output layout.
                    ( 'Color', 'Emission' ),
                ],
                after=[ 'Normal' ]
            )
    else:
        nb.remove_inputs([ name ])
        nb.remove_outputs([ 'Emission' ])


def _enable_displacement(nb: NodeBuilderCycles, enable: bool):
    """
    Add or remove the displacement sockets of the PBR group.

    Inputs: 'Displacement Intensity' (default 1.0) and 'Displacement Midlevel'
    (default 0.5), placed after 'Emission Strength' / 'Blend Normal'.
    Outputs: 'Displacement Height' (Color) and 'Displacement' (Vector), placed
    after 'Emission' / 'Normal'. When the inputs are newly added, their
    defaults are pushed to existing instances. When disabled, all four
    sockets are removed.
    """
    strength_in, midlevel_in = 'Displacement Intensity', 'Displacement Midlevel'
    height_out, vector_out = 'Displacement Height', 'Displacement'

    if not enable:
        nb.remove_inputs([strength_in, midlevel_in])
        nb.remove_outputs([height_out, vector_out])
        return

    # Each pair is always added together, so the first socket stands in for both.
    if not nb.tree_has_input(strength_in):
        nb.add_inputs(
            [
                ('FloatFactor', strength_in, 1.0, 0.0, 1.0),
                ('FloatFactor', midlevel_in, 0.5, 0.0, 1.0),
            ],
            after=['Emission Strength', 'Blend Normal']
        )
        for socket_name, default in ((strength_in, 1.0), (midlevel_in, 0.5)):
            _update_all_default_values(nb.tree, socket_name, default)

    if not nb.tree_has_output(height_out):
        nb.add_outputs(
            [
                ('Color', height_out),
                ('Vector', vector_out),
            ],
            after=['Emission', 'Normal']
        )


def _enable_parallax(nb: NodeBuilderCycles, enable: bool):
    """
    Add or remove the parallax sockets of the PBR group.

    Inputs: 'Parallax Min' (0.0), 'Parallax Max' (1.0) and 'Parallax Depth'
    (1.0, range 0-8), placed after 'Bevel Radius' / 'Emission Strength' /
    'Blend Normal'. Outputs: 'Parallax UV' (Vector) and 'Parallax MinMax'
    (Float), placed after 'Emission' / 'Normal'. When the inputs are newly
    added, their defaults are pushed to existing instances. When disabled,
    all five sockets are removed.

    A 'Parallax Normal Mix' input (default 0.5) was sketched here in the past
    but is currently disabled.
    """
    min_in, max_in, depth_in = 'Parallax Min', 'Parallax Max', 'Parallax Depth'
    uv_out, min_max_out = 'Parallax UV', 'Parallax MinMax'

    if not enable:
        nb.remove_inputs([min_in, max_in, depth_in])
        nb.remove_outputs([uv_out, min_max_out])
        return

    # The three inputs are always added together; 'Parallax Min' marks the set.
    if not nb.tree_has_input(min_in):
        nb.add_inputs(
            [
                ('FloatFactor', min_in, 0.0, 0.0, 1.0),
                ('FloatFactor', max_in, 1.0, 0.0, 1.0),
                ('FloatFactor', depth_in, 1.0, 0.0, 8.0),
            ],
            after=['Bevel Radius', 'Emission Strength', 'Blend Normal']
        )
        for socket_name, default in ((min_in, 0.0), (max_in, 1.0), (depth_in, 1.0)):
            _update_all_default_values(nb.tree, socket_name, default)

    if not nb.tree_has_output(uv_out):
        nb.add_outputs(
            [
                ('Vector', uv_out),
                ('Float', min_max_out),
            ],
            after=['Emission', 'Normal']
        )


def _enable_bevel(nb: NodeBuilderCycles, enable: bool):
    """
    Add or remove the 'Bevel Radius' factor input (default 0.05, range 0-1).

    When it is newly added (after 'Normal Strength'), the default is pushed to
    existing group instances. When disabled, the input is removed.
    """
    socket_name = 'Bevel Radius'
    default = 0.05
    if not enable:
        nb.remove_inputs([socket_name])
        return
    if nb.tree_has_input(socket_name):
        return
    nb.add_inputs(
        [('FloatFactor', socket_name, default, 0.0, 1.0)],
        after=['Normal Strength']
    )
    _update_all_default_values(nb.tree, socket_name, default)


def _update_parallax_min_max(nb: NodeBuilderCycles, mm: Tuple[float, float]):
    """
    Push `mm` = (min, max) to the 'Parallax Min'/'Parallax Max' inputs of all
    instances of the group.

    Does nothing unless both inputs exist on the tree.
    """
    lo_name, hi_name = 'Parallax Min', 'Parallax Max'
    if not (nb.tree_has_input(lo_name) and nb.tree_has_input(hi_name)):
        return
    lo, hi = mm[0], mm[1]
    _update_all_default_values(nb.tree, lo_name, lo)
    _update_all_default_values(nb.tree, hi_name, hi)


def _add_uv_mapping(nb: NodeBuilderCycles, mapping: str):
    """
    Add the texture-coordinate nodes for `mapping` and return the UV socket.

    'UV' feeds the group's 'UV' and 'Scale' inputs to `nb.add_auto_uv`.
    'LBOX', 'GENBOX' and 'GBOX' call the matching box-mapping helper with the
    group's 'Scale' input.

    :raises ValueError: for any other mapping name. ValueError is a subclass
        of Exception, so callers that caught the previous generic Exception
        keep working.
    """
    if mapping == 'UV':
        return nb.add_auto_uv(nb.input('UV'), nb.input('Scale'))
    if mapping == 'LBOX':
        return nb.add_local_box_uv_mapping(nb.input('Scale'))
    if mapping == 'GENBOX':
        return nb.add_generate_box_uv_mapping(nb.input('Scale'))
    if mapping == 'GBOX':
        return nb.add_global_box_uv_mapping(nb.input('Scale'))
    raise ValueError(f'Unknown UV Mapping: {mapping!r}')


def setup_pbr(
    nb: NodeBuilderCycles,
    category: str,
    pbr: str,
    variant: int,
    bin_hash: str,
    mapping: str,
    blend: float = 0.5,
    interpolation: str = 'Linear',
    anti_repeat_style: str = 'OFF',
    anti_repeat_scale: float = 2,
    anti_repeat_distortion: float = 0.1,
    anti_repeat_style_param_0: float = 0.1,
    anti_repeat_seed: float = 0,
    depth_mode: str = 'NONE',
    pom_levels: int = 8,
    pom_fine_levels: int = 4,
    bevel_samples: int = 4,
    invert_normal: bool = False
    ):
    """
    Build the contents of a PBR node group for one texture set.

    The steps are:

    1. Look up the PBR image `category`/`pbr` in the TextureRegistry and name
       the tree after it.
    2. Store all settings in ``tree.snw_node_info``.
    3. Add or remove optional sockets to match the maps the texture provides
       and the requested features.
    4. Build the shading network and wire it to the group outputs.
    5. Arrange the nodes.

    :param nb: Builder wrapping the node group tree to set up.
    :param category: Texture category in the registry.
    :param pbr: Texture name within the category.
    :param variant: 0 for the base texture set, N > 0 for ``variants[N - 1]``.
    :param bin_hash: Only recorded in the node info.
    :param mapping: 'UV', 'LBOX', 'GENBOX' or 'GBOX'. Every mapping except
        'UV' uses box projection and gets a 'Blend' input.
    :param blend: Only recorded in the node info here; the 'Blend' socket
        itself is created with default 0.5 by `_enable_blend`.
    :param interpolation: Passed to `build_pbr_node_setup`.
    :param anti_repeat_style: 'OFF' disables the anti-repeat UV stage.
        Otherwise it is passed to `nb.add_uv_anti_repeat` together with
        `anti_repeat_scale`, `anti_repeat_distortion`,
        `anti_repeat_style_param_0` and `anti_repeat_seed`.
    :param depth_mode: 'NONE', 'PARALLAX', 'POM' or 'DISPLACE'. Only takes
        effect when the texture has a displacement map.
    :param pom_levels: Passed to `create_parallax_setup`.
    :param pom_fine_levels: Passed to `create_parallax_setup`.
    :param bevel_samples: A value > 0 adds a 'Bevel Radius' input and is
        passed to the node setup.
    :param invert_normal: Passed to `build_pbr_node_setup`.
    :raises ValueError: if `mapping` is not one of the names listed above.
    """
    dbg(f'Update PBR: {category}/{pbr}/{variant} with {mapping}/{interpolation} and depth={depth_mode}/{pom_levels}')
    imagex = TextureRegistry.instance().get_image(RepositoryType.PBR, category, pbr)

    nb.tree.name = imagex.name
    nb.tree.snw_node_info = encode_pbr_info(
        '',
        category,
        pbr,
        variant,
        bin_hash,
        mapping,
        blend,
        interpolation,
        anti_repeat_style,
        anti_repeat_scale,
        anti_repeat_distortion,
        anti_repeat_style_param_0,
        anti_repeat_seed,
        depth_mode, # NONE, PARALLAX, POM, DISPLACE
        str(pom_levels),
        str(pom_fine_levels),
        bevel_samples,
        invert_normal,
    )

    pbr_data = imagex.info.pbr
    if variant > 0:
        pbr_data = pbr_data.variants[variant - 1]

    # --- Texture coordinates -------------------------------------------------
    uv = _add_uv_mapping(nb, mapping)
    proj = 'FLAT' if mapping == 'UV' else 'BOX'

    _enable_uv_offset(nb) # For older nodes.
    uv = nb.add_vector_math_add(uv, nb.input('UV Offset'))

    # Box projection needs a blend factor between the projected faces.
    _enable_blend(nb, proj == 'BOX')
    blend_socket = nb.input('Blend') if proj == 'BOX' else None

    if anti_repeat_style != 'OFF':
        uv = nb.add_uv_anti_repeat(
            uv,
            anti_repeat_style,
            anti_repeat_scale,
            anti_repeat_distortion,
            anti_repeat_style_param_0,
            anti_repeat_seed
        )

    # --- Optional sockets driven by the available maps -----------------------
    _enable_ao(nb, pbr_data.has_ao())
    ao_intensity = None
    if pbr_data.has_ao():
        ao_intensity = nb.input('AO Intensity')

    _enable_specular(nb, pbr_data.has_specular())
    _enable_metallic(nb, pbr_data.has_metallic())

    # One or none can be active.
    is_parallax = depth_mode in [ 'PARALLAX', 'POM' ] and pbr_data.has_displace()
    is_displace = depth_mode == 'DISPLACE' and pbr_data.has_displace()

    _enable_displacement(nb, is_displace)
    _enable_parallax(nb, is_parallax)

    # Used to blend parallax normal to PBR normal. Currently always None; the
    # matching 'Parallax Normal Mix' input is disabled in _enable_parallax.
    parallax_nrm = None
    parallax_nrm_weight = None
    if is_parallax:
        # Auto set min/max parallax range, read from displace image.
        _update_parallax_min_max(nb, find_displace_min_max(pbr_data))

        uv, minMax = create_parallax_setup(
            nb,
            pbr_data,
            depth_mode,
            pom_levels,
            pom_fine_levels,
            uv,
            proj,
            nb.input('Parallax Min'),
            nb.input('Parallax Max'),
            nb.input('Parallax Depth'),
        )
        nb.wire_sockets(uv, nb.output('Parallax UV'))
        nb.wire_sockets(minMax, nb.output('Parallax MinMax'))

    _enable_emission(nb, pbr_data.has_emission())
    emission_strength = None
    if pbr_data.has_emission():
        emission_strength = nb.input('Emission Strength')

    _enable_bevel(nb, bevel_samples > 0)
    bevel_socket = None
    if bevel_samples > 0:
        bevel_socket = nb.input('Bevel Radius')

    # --- Shading network -----------------------------------------------------
    outputs = build_pbr_node_setup(
        nb,
        pbr_data,
        uv,
        proj,
        blend_socket,
        interpolation,
        nb.input('Hue'),
        nb.input('Saturation'),
        nb.input('Value'),
        nb.input('Brightness'),
        nb.input('Contrast'),
        ao_intensity,
        nb.input('Roughness Offset'),
        nb.input('Normal Strength'),
        blend_normal_object=nb.input('Blend Normal'),
        blend_normal_object_weight=nb.input('Blend Normal Weight'),
        blend_normal_parallax=parallax_nrm,
        blend_normal_parallax_weight=parallax_nrm_weight,
        emission_strength=emission_strength,
        default_roughness=0.0,
        enable_displace=is_displace,
        bevel_distance=bevel_socket,
        bevel_samples=bevel_samples,
        invert_normal=invert_normal,
    )

    # --- Group outputs -------------------------------------------------------
    nb.wire_sockets(outputs['diffuse_raw'], nb.output('Color'))
    nb.wire_sockets(outputs['diffuse'], nb.output('Processed Color'))
    nb.wire_sockets(outputs['alpha'], nb.output('Alpha'))
    nb.wire_sockets(outputs['roughness'], nb.output('Roughness'))
    nb.wire_sockets(outputs['shader'], nb.output('Shader'))
    if outputs['normal_object']:
        nb.wire_sockets(outputs['normal_object'], nb.output('Normal'))
    if pbr_data.has_specular():
        nb.wire_sockets(outputs['specular'], nb.output('Specular'))
    if pbr_data.has_metallic():
        nb.wire_sockets(outputs['metallic'], nb.output('Metallic'))
    if pbr_data.has_emission():
        nb.wire_sockets(outputs['emission'], nb.output('Emission'))
    if is_displace:
        displacement = nb.add_displacement(
            height=outputs['displace'],
            midlevel=nb.input('Displacement Midlevel'),
            scale=nb.input('Displacement Intensity')
        ).outputs[0]
        nb.wire_sockets(outputs['displace'], nb.output('Displacement Height'))
        nb.wire_sockets(displacement, nb.output('Displacement'))

    nb.arrange()
