# Copyright (C) 2022 Thomas Hoppe (h0bB1T). All rights reserved.
#
# Unauthorized copying of this file via any medium is strictly prohibited.
# Proprietary and confidential.

import bpy

from bpy.types import NodeSocket

from typing import Dict, Tuple, Union
from numpy import histogram

from ..scanners.pbr_entry import PBREntry
from ....utils.blender import find_or_load_image
from ....utils.node_builder_cycles import NodeBuilderCycles
from ....constants import snw_rel_path


def _add_raw_texture(
    nb: NodeBuilderCycles,
    name: str,
    colorspace: str,
    uv_vec: NodeSocket,
    interpolation: str
    ) -> Tuple[NodeSocket, NodeSocket]:
    """
    Add one flat-projected image texture node for the image `name`.

    The image is obtained through `find_or_load_image`; the node uses
    `uv_vec` as lookup vector together with the given colorspace and
    interpolation. Returns the node's ('Color', 'Alpha') output sockets.
    """
    image = find_or_load_image(name)
    texture_node = nb.add_image(
        image,
        projection='FLAT',
        colorspace=colorspace,
        blend=0.0,
        vector=uv_vec,
        interpolation=interpolation,
    )
    sockets = texture_node.outputs
    return (sockets['Color'], sockets['Alpha'])


def _add_uvw_texture(
    nb: NodeBuilderCycles,
    name: str,
    colorspace: str,
    uv_vec: NodeSocket,
    interpolation: str,
    blend: NodeSocket, 
    invert_y: bool = None, # Normal specific
    scale_z: Union[None, NodeSocket] = None, # Normal specific
    ):
    """
    Add the texture `name` with its own UVW mapping (three lookups).

    The '.SNW-UVWPre' group splits `uv_vec` into one lookup vector per
    plane ('YZ', 'XZ', 'XY'); the texture is sampled once per plane and the
    '.SNW-UVWPost' group combines the samples using `blend`.

    `invert_y` and `scale_z` are meant for normal maps: when set, each
    sample gets its Y channel flipped and/or is passed through
    `nb.scale_normal` before blending.

    Returns the ('Color', 'Alpha', 'Normal') outputs of the post group.
    """
    # Pre-splitting: UVW -> three separate UV vectors.
    splitter = nb.add_load_group(
        '.SNW-UVWPre', 
        snw_rel_path('data', 'nodes', 'essential.blend'),
        [ ( 'Vector', uv_vec ) ]
    )

    # Inputs of the post group: the blend factor first, then for every plane
    # the color under the plane's name and the alpha under '<plane>A'.
    post_inputs = [ ( 'Blend', blend ) ]
    for plane in ( 'YZ', 'XZ', 'XY' ):
        rgb, a = _add_raw_texture(nb, name, colorspace, splitter.outputs[plane], interpolation)
        if invert_y:
            # abs((0, 1, 0) - c): for components within [0, 1] this keeps
            # X and Z and turns Y into 1 - Y.
            difference = nb.add_vector_math_sub((0, 1, 0), rgb)
            rgb = nb.add_vector_math_abs(difference)
        if scale_z:
            rgb = nb.scale_normal(rgb, scale_z)
        post_inputs += [ (plane, rgb), (f'{plane}A', a) ]

    # Blend the three lookups together (the group also does the dot3 normal mapping).
    combiner = nb.add_load_group(
        '.SNW-UVWPost', 
        snw_rel_path('data', 'nodes', 'essential.blend'),
        post_inputs
    )

    blended = combiner.outputs
    return (blended['Color'], blended['Alpha'], blended['Normal'])


def _add_parallax_displace_texture(
        nb: NodeBuilderCycles,
        name: str,
        uv_vec: NodeSocket,
        projection: str,
        interpolation: str,
        blend: NodeSocket
    ) -> NodeSocket:
    """
    Add the displace map `name` as 'Non-Color' texture and return its color socket.

    With 'FLAT' projection a single plain lookup is created; every other
    projection uses the UVW (three lookup) variant, which also consumes `blend`.
    """
    if projection != 'FLAT':
        sockets = _add_uvw_texture(nb, name, 'Non-Color', uv_vec, interpolation, blend)
    else:
        sockets = _add_raw_texture(nb, name, 'Non-Color', uv_vec, interpolation)
    # In both variants the color socket is the first element.
    return sockets[0]
    

def find_displace_min_max(pbr: PBREntry) -> Tuple[float, float]:
    """
    Find min and max values from the displace map.

    Only every 4th entry of the image's flat pixel buffer is inspected,
    starting at index 0 (presumably the red channel of an RGBA layout --
    confirm if images with another channel count can show up here).

    Returns `(0.0, 1.0)` when the entry has no displace map or when the
    loaded image provides no pixel data.
    """
    if not pbr.displace:
        return (0.0, 1.0)

    img = find_or_load_image(pbr.displace)
    values = img.pixels[:][0::4]
    if not values:
        # Nothing to measure (e.g. image without pixel data): use the default range.
        return (0.0, 1.0)

    # Take the exact extremes. Pairing histogram counts with their bin edges
    # only ever yields *left* edges, so the upper bound would end up one bin
    # width below the real maximum.
    return (min(values), max(values))


def build_pbr_parallax_heightmap(
    pbr: PBREntry,
    interpolation: str,
    projection: str
    ) -> str:
    """
    Build (once) the height map node group used for parallax mapping and
    return the group's name.

    The group samples the entry's displace map at 'UV' and sends the value
    through a clamped map-range node driven by the 'Min' and 'Max' inputs;
    the result is exposed as 'Height'. Projections other than 'FLAT' get an
    additional 'Blend' input. If a group with the name already exists,
    nothing is created.
    """
    name = f'.snw_parallax_hmv2_{projection.lower()}_{pbr.name}'
    if name in bpy.data.node_groups:
        return name

    nb = NodeBuilderCycles()
    nb.build_node_tree(name) 
    nb.add_inputs([
        ( 'Vector', 'UV', (0.0, 0.0, 0.0) ),
        ( 'Float', 'Min', 0.0 ),
        ( 'Float', 'Max', 1.0 )
    ])

    # The blend factor is only required by the UVW (non-flat) lookup.
    if projection == 'FLAT':
        blend = None
    else:
        nb.add_inputs([ ( 'FloatFactor', 'Blend', 0.5, 0.0, 1.0 ) ])
        blend = nb.input('Blend')

    nb.add_outputs([
        ( 'Float', 'Height' ),
    ])

    sample = _add_parallax_displace_texture(nb, pbr.displace, nb.input('UV'), projection, interpolation, blend)
    remap = nb.add_map_range(
        value=sample,
        clamp=True,
        fmin=nb.input('Min'),
        fmax=nb.input('Max')
    )

    nb.wire_sockets(remap.outputs[0], nb.output('Height'))
    nb.arrange()
    return name


def build_pbr_parallax_level(
    pbr: PBREntry,
    interpolation: str,
    projection: str
    ) -> str:
    """
    Build (once) the node group for a single level of parallax occlusion
    mapping on the entry's displace map and return the group's name.
    If a group with the name already exists, nothing is created.

    Socket contract, as wired below (sample = displace lookup at input V0,
    hit = less_than(sample, F0) * F3):
      V0 out = V0 + V1
      V1 out = V1 (passed through)
      F0 out = F0 - F1
      F1 out = F1 (passed through)
      F2 out = F2 + hit
      F3 out = hit
      F4 out = sample
    Projections other than 'FLAT' get an additional 'Blend' input.
    """
    name = f'.snw_parallax_levelv2_{projection.lower()}_{pbr.name}'
    if name in bpy.data.node_groups:
        return name

    nb = NodeBuilderCycles()
    nb.build_node_tree(name)

    nb.add_inputs([
        ( 'Vector', 'V0', (0.0, 0.0, 0.0) ),
        ( 'Vector', 'V1', (0.0, 0.0, 0.0) ),
        ( 'Float', 'F0', 0.0 ),
        ( 'Float', 'F1', 0.0 ),
        ( 'Float', 'F2', 0.0 ),
        ( 'Float', 'F3', 1.0 ),
    ])

    # The blend factor is only required by the UVW (non-flat) lookup.
    if projection == 'FLAT':
        blend = None
    else:
        nb.add_inputs([ ( 'FloatFactor', 'Blend', 0.5, 0.0, 1.0 ) ])
        blend = nb.input('Blend')

    nb.add_outputs([
        ( 'Vector', 'V0' ),
        ( 'Vector', 'V1' ),
        ( 'Float', 'F0' ),
        ( 'Float', 'F1' ),
        ( 'Float', 'F2' ),
        ( 'Float', 'F3' ),
        ( 'Float', 'F4' ),
    ])

    sample = _add_parallax_displace_texture(nb, pbr.displace, nb.input('V0'), projection, interpolation, blend)

    # Multiplying with the incoming F3 means: once F3 arrives as 0, hit stays 0.
    below = nb.add_math_less_than(sample, nb.input('F0'))
    hit = nb.add_math_mul(below, nb.input('F3'))

    # Vectors: advance V0 by V1, hand V1 on unchanged.
    next_v0 = nb.add_vector_math_add(nb.input('V0'), nb.input('V1'))
    nb.wire_sockets(next_v0, nb.output('V0'))
    nb.wire_sockets(nb.input('V1'), nb.output('V1'))

    # Floats: reduce F0 by F1, hand F1 on unchanged.
    next_f0 = nb.add_math_sub(nb.input('F0'), nb.input('F1'))
    nb.wire_sockets(next_f0, nb.output('F0'))
    nb.wire_sockets(nb.input('F1'), nb.output('F1'))

    # F2 accumulates the hits of all levels so far.
    hit_count = nb.add_math_add(nb.input('F2'), hit)
    nb.wire_sockets(hit_count, nb.output('F2'))

    nb.wire_sockets(hit, nb.output('F3'))
    nb.wire_sockets(sample, nb.output('F4'))

    nb.arrange()
    return name


def build_pbr_parallax_occlusion(
    pbr: PBREntry,
    interpolation: str,
    projection: str
    ) -> str:    
    """
    Build (once) the parallax occlusion mapping weight node group for the
    entry's displace map and return the group's name. If a group with the
    name already exists, nothing is created.

    With n = F2 - 1 the group looks at two neighbouring positions:
      uv_before = V0 + V1 * n          depth_before = F0 - F1 * n
      uv_after  = uv_before + V1       depth_after  = depth_before - F1
    samples the displace map at both, and outputs
      V0 = uv_before + V1 * weight
    with weight = before / (before + after), where
      before = depth_before - sample(uv_before)
      after  = sample(uv_after) - depth_after
    """
    name = f'.snw_parallax_occlusionv2_{projection.lower()}_{pbr.name}'
    if name in bpy.data.node_groups:
        return name

    nb = NodeBuilderCycles()
    nb.build_node_tree(name)

    # NOTE(review): a 'Blend' input is declared here unconditionally and a
    # second one is added below for non-FLAT projections; the sibling
    # builders only add it conditionally. Kept as is, since callers may rely
    # on the current sockets -- confirm whether this duplicate is intended.
    nb.add_inputs([
        ( 'Vector', 'V0', (0.0, 0.0, 0.0) ),
        ( 'Vector', 'V1', (0.0, 0.0, 0.0) ),
        ( 'Float', 'F0', 0.0 ),
        ( 'Float', 'F1', 0.0 ),
        ( 'Float', 'F2', 0.0 ),
        ( 'Float', 'Blend', 0.5 )
    ])

    if projection == 'FLAT':
        blend = None
    else:
        nb.add_inputs([ ( 'FloatFactor', 'Blend', 0.5, 0.0, 1.0 ) ])
        blend = nb.input('Blend')

    nb.add_outputs([
        ( 'Vector', 'V0' ),
    ])

    steps = nb.add_math_sub(nb.input('F2'), 1)

    # Lookup positions before and after the last step.
    start_uv = nb.input('V0')
    uv_offset = nb.add_vector_math_scale(nb.input('V1'), steps)
    uv_prev = nb.add_vector_math_add(start_uv, uv_offset)
    uv_next = nb.add_vector_math_add(uv_prev, nb.input('V1'))

    sample_prev = _add_parallax_displace_texture(nb, pbr.displace, uv_prev, projection, interpolation, blend)
    sample_next = _add_parallax_displace_texture(nb, pbr.displace, uv_next, projection, interpolation, blend)

    # Matching depth values at both positions.
    start_depth = nb.input('F0')
    depth_offset = nb.add_math_mul(nb.input('F1'), steps)
    depth_prev = nb.add_math_sub(start_depth, depth_offset)
    depth_next = nb.add_math_sub(depth_prev, nb.input('F1'))

    # Distances between depth and sample on either side give the weight.
    gap_prev = nb.add_math_sub(depth_prev, sample_prev)
    gap_next = nb.add_math_sub(sample_next, depth_next)
    gap_total = nb.add_math_add(gap_prev, gap_next)
    weight = nb.add_math_div(gap_prev, gap_total)

    # Move from the earlier position towards the later one by `weight`.
    weighted_step = nb.add_vector_math_scale(nb.input('V1'), weight)
    final_uv = nb.add_vector_math_add(uv_prev, weighted_step)
    nb.wire_sockets(final_uv, nb.output('V0'))

    nb.arrange()
    return name


def build_pbr_node_setup(
    nb: NodeBuilderCycles,
    pbr: PBREntry,
    uv: NodeSocket,
    projection: str,
    blend: NodeSocket,
    interpolation: str,
    hue: NodeSocket,
    saturation: NodeSocket,
    value: NodeSocket,
    brightness: NodeSocket,
    contrast: NodeSocket,
    ao_intensity: NodeSocket,
    roughness_offset: NodeSocket,
    normal_strength: NodeSocket,
    blend_normal_object: Union[None, NodeSocket] = None,
    blend_normal_object_weight: Union[None, NodeSocket] = None,
    blend_normal_parallax: Union[None, NodeSocket] = None,
    blend_normal_parallax_weight: Union[None, NodeSocket] = None,
    displacement_strength: Union[None, NodeSocket] = None,
    emission_strength: Union[None, NodeSocket] = None,
    is_dx_normal: bool = False,
    default_roughness: float = 0.5,
    enable_displace: bool = False,
    bevel_distance: Union[None, NodeSocket] = None,
    bevel_samples: int = 4,
    invert_normal: bool = False
    ) -> Dict[str, NodeSocket]:
    """
    Uniquely setup the PBR entry into the given node tree.

    Creates the texture nodes for all maps present in `pbr`, the color
    adjustments, the normal handling and a Principled BSDF, all through `nb`.

    Parameters:
      nb                   builder all nodes are created with.
      pbr                  entry whose map names are read (arm, diffuse, ao,
                           alpha, roughness, glossiness, metal, specular,
                           emission, displace, normal, height).
      uv                   lookup vector used for every texture.
      projection           'FLAT' creates one lookup per texture, every other
                           value the UVW (three lookup) variant.
      blend                blend factor of the UVW variant (unused for 'FLAT').
      interpolation        interpolation passed to every image node.
      hue, saturation, value, brightness, contrast
                           color adjustments applied to the diffuse map.
      ao_intensity         factor with which AO is applied to the base color.
      roughness_offset     added (clamped) to the roughness.
      normal_strength      strength for normal map / bump map.
      blend_normal_object(_weight), blend_normal_parallax(_weight)
                           optional normals mixed into the result; each pair
                           is only used if both sockets are given.
      displacement_strength
                           optional factor fading the displace map in from black.
      emission_strength    optional socket wired to 'Emission Strength'.
      is_dx_normal, invert_normal
                           XOR-combined: if exactly one is set, the Y channel
                           of the normal map is inverted.
      default_roughness    constant used if no roughness source exists.
      enable_displace      only if set, the displace map is loaded.
      bevel_distance, bevel_samples
                           optional bevel applied to the final normal.

    Returns a dict with the sockets 'diffuse', 'alpha', 'metallic',
    'specular', 'roughness', 'emission', 'normal_object' and 'shader', plus
    'diffuse_raw' if a diffuse map exists and 'displace' if a displace
    socket was created. Wiring them is up to the caller.
    """
    
    
    def add_texture_with_alpha(name: str, colorspace: str = 'Non-Color') -> Tuple[NodeSocket, NodeSocket]:
        """
        Load texture map (FLAT | BOX) and return color and alpha.
        """
        if projection == 'FLAT':
            return _add_raw_texture(nb, name, colorspace, uv, interpolation)
        else:
            color, alpha, _ = _add_uvw_texture(nb, name, colorspace, uv, interpolation, blend)
            return (color, alpha)

    
    def add_texture_normal(name: str, invert_y: bool) -> Tuple[NodeSocket, NodeSocket]:
        """
        Load normal map (FLAT | BOX) and return color and normal in object space.
        """
        if projection == 'FLAT':
            color = _add_raw_texture(nb, name, 'Non-Color', uv, interpolation)[0]
            if invert_y:
                # abs((0, 1, 0) - c): for components within [0, 1] this keeps
                # X and Z and turns Y into 1 - Y.
                color = nb.add_vector_math_abs(nb.add_vector_math_sub((0, 1, 0), color))
            normal = nb.add_normal_map(strength=normal_strength, color=color).outputs[0]
            return (color, normal)
        else:
            # In the UVW variant the strength is applied per lookup (scale_z).
            color, _, normal = _add_uvw_texture(nb, name, 'Non-Color', uv, interpolation, blend, invert_y, normal_strength)
            return (color, normal)
        
    
    # Use to add texture setup for a single channel, returns just color.
    def add_texture(name: str, colorspace: str = 'Non-Color') -> bpy.types.NodeSocketColor:
        """
        Load texture map (FLAT | BOX) and return color.
        """
        return add_texture_with_alpha(name, colorspace)[0]


    # This is returned and will be wired by the caller.
    outputs = {} # type: Dict[str, NodeSocket]
    
    # Special format which contains AO, Roughness and Metal in one texture file (Polyhaven).
    ao, roughness, metallic = None, None, None
    if pbr.arm:
        arm_rgb = nb.add_separate_xyz(add_texture(pbr.arm))
        ao, roughness, metallic = arm_rgb.outputs[0], arm_rgb.outputs[1], arm_rgb.outputs[2]

    # Setup diffuse with HSV & Contrast + AO, store alpha for later use.
    if pbr.diffuse is not None:
        basecolor, tx_alpha = add_texture_with_alpha(pbr.diffuse, colorspace='sRGB')
        outputs['diffuse_raw'] = basecolor
        basecolor = nb.add_hsv(hue=hue, saturation=saturation, value=value, color=basecolor).outputs['Color']
        basecolor = nb.add_brightness_contrast(brightness=brightness, contrast=contrast, color=basecolor).outputs['Color']
        if ao: # ARM
            basecolor = nb.add_hsv(value=ao, factor=ao_intensity, color=basecolor).outputs[0]
        elif pbr.ao:
            ao = add_texture(pbr.ao)
            basecolor = nb.add_hsv(value=ao, factor=ao_intensity, color=basecolor).outputs[0]
    else:
        # No diffuse map: fall back to a constant magenta base color.
        basecolor = nb.add_color((1.0, 0.0, 1.0, 1.0)).outputs[0]

    # Either use alpha texture, alpha from diffuse texture or constant 1.0.
    if pbr.alpha is not None:
        alpha = add_texture(pbr.alpha)
    elif pbr.diffuse is not None:
        alpha = tx_alpha
    else:
        alpha = nb.add_value(1.0).outputs[0]

    # Use roughness or inverted glossiness.
    if not roughness: # ARM
        if pbr.roughness is not None:
            roughness = add_texture(pbr.roughness)
        elif pbr.glossiness is not None:
            roughness = nb.add_math_sub(1.0, add_texture(pbr.glossiness))
        else:
            roughness = nb.add_value(default_roughness).outputs[0]
    roughness = nb.add_math_add(roughness, roughness_offset, clamp=True)

    # Load other channels.
    if not metallic: # ARM
        metallic = add_texture(pbr.metal) if (pbr.metal) else nb.add_value(0.0).outputs[0]
    specular = add_texture(pbr.specular) if (pbr.specular) else nb.add_value(0.5).outputs[0]
    emission = add_texture(pbr.emission) if (pbr.emission) else nb.add_color((0.0, 0.0, 0.0, 1.0)).outputs[0]

    # In case we have displace and should do real displacement mapping.
    displace = None
    if pbr.displace is not None and enable_displace:
        displace = add_texture(pbr.displace)
        if displacement_strength:
            displace = nb.add_mix_color('MIX', fac=displacement_strength, color0=(0, 0, 0, 1), color1=displace)

    if pbr.normal is not None:
        _, normal = add_texture_normal(pbr.normal, is_dx_normal ^ invert_normal)
    elif pbr.height is not None:
        # NOTE(review): the result is used directly, while the sibling branch
        # takes `.outputs[0]` of add_normal_map -- confirm that add_bump_map
        # returns a socket and not a node.
        normal = nb.add_bump_map(add_texture(pbr.height), strength=normal_strength, distance=0.001)
    else:
        normal = nb.add_normal_map().outputs[0]

    # Eventually mix with normal from node inputs.
    if blend_normal_object and blend_normal_object_weight:
        normal = nb.add_normal_mixer(
            blend_normal_object,
            normal,
            blend_normal_object_weight
        )

    # Eventually mix with parallax normal.
    if blend_normal_parallax and blend_normal_parallax_weight:
        normal = nb.add_normal_mixer(
            blend_normal_parallax,
            normal,
            blend_normal_parallax_weight
        )

    # Eventually add bevel.
    if bevel_distance:
        normal = nb.add_bevel(samples=bevel_samples, radius=bevel_distance, normal=normal).outputs[0]

    # Create the final shader.
    shader = nb.add_principled_bsdf(
        basecolor=basecolor,
        metallic=metallic,
        specular=specular,
        roughness=roughness,
        normal=normal,
        alpha=alpha,
        emission=emission
    )
    if emission_strength:
        nb.wire_sockets(emission_strength, shader.inputs['Emission Strength'])

    # Store for output mapping.
    outputs['diffuse'] = basecolor
    outputs['alpha'] = alpha
    outputs['metallic'] = metallic
    outputs['specular'] = specular
    outputs['roughness'] = roughness
    outputs['emission'] = emission
    outputs['normal_object'] = normal
    outputs['shader'] = shader.outputs[0]
    if displace:
        outputs['displace'] = displace

    return outputs
