# Copyright (C) 2022 Thomas Hoppe (h0bB1T). All rights reserved.
#
# Unauthorized copying of this file via any medium is strictly prohibited.
# Proprietary and confidential.

from typing import Tuple

from bpy.types import NodeSocketVector, NodeSocketFloat

from .simple import create_incoming_tangent
from ...utils.node_builder_cycles import NodeBuilderCycles
from ...registries.textures.builders.pbr_builder import PBREntry, build_pbr_parallax_heightmap, build_pbr_parallax_level, build_pbr_parallax_occlusion
    

def __create_parallax_basic(
    nb: NodeBuilderCycles,
    pbr_data: PBREntry,
    incoming: NodeSocketVector,
    uv: NodeSocketVector,
    proj: str,
    interpolation: str,
    parallax_min: NodeSocketFloat,
    parallax_max: NodeSocketFloat,
    parallax_depth: NodeSocketFloat,
    ) -> Tuple[NodeSocketVector, NodeSocketFloat]:
    """
    Build simple (single-sample) parallax offset mapping.

    The height map group is sampled once at `uv`. The UV is then shifted
    against the `incoming` vector by an amount proportional to
    (1 - height) * parallax_depth * 0.01.

    Returns:
        (newUV, rawDisplacement): the shifted UV socket and the height map
        output sampled at the original UV.
    """
    heightmap_group = build_pbr_parallax_heightmap(pbr_data, interpolation, proj)
    height_node = nb.add_group(heightmap_group)
    for source, target in (
        (uv, 'UV'),
        (parallax_min, 'Min'),
        (parallax_max, 'Max'),
    ):
        nb.wire_sockets(source, height_node.inputs[target])
    raw_height = height_node.outputs[0]

    # offset = incoming * (depth * 0.01) * (1 - height): lower height values
    # bend the lookup further along the view vector.
    depth_factor = nb.add_math_mul(parallax_depth, 0.01)
    view_step = nb.add_vector_math_scale(incoming, depth_factor)
    inverted_height = nb.add_math_sub(1.0, raw_height)
    offset = nb.add_vector_math_scale(view_step, inverted_height)
    shifted_uv = nb.add_vector_math_sub(uv, offset)

    return shifted_uv, raw_height


# Number of leading output sockets of a parallax level node that are wired
# into the same-index inputs of the next level node in a chain.
_POM_LEVEL_LINKS = 6


def _chain_pom_levels(nb: NodeBuilderCycles, level_name, count: int) -> list:
    """
    Instantiate `count` parallax level group nodes and chain them.

    Each node's first _POM_LEVEL_LINKS outputs are wired into the same-index
    inputs of the following node (0->1, 1->2, ...). Returns the nodes in
    creation order.
    """
    nodes = []
    for _ in range(count):
        node = nb.add_group(level_name)
        if nodes:
            previous = nodes[-1]
            for socket in range(_POM_LEVEL_LINKS):
                nb.wire_sockets(previous.outputs[socket], node.inputs[socket])
        nodes.append(node)
    return nodes


def __create_parallax_occlusion(
    nb: NodeBuilderCycles,
    pbr_data: PBREntry,
    pom_levels: int,
    pom_fine_levels: int,
    incoming: NodeSocketVector,
    uv: NodeSocketVector,
    proj: str,
    interpolation: str,
    parallax_min: NodeSocketFloat,
    parallax_max: NodeSocketFloat,
    parallax_depth: NodeSocketFloat,
    ) -> Tuple[NodeSocketVector, NodeSocketFloat]:
    """
    Parallax Occlusion Mapping, returns (newUV, rawDisplacement).

    Coarse pass: a chain of `pom_levels` level nodes starts at `uv` /
    `parallax_max`. Per level, it steps the UV by a view-dependent vector and
    the depth by (parallax_max - parallax_min) / pom_levels.

    Fine pass (optional): if `pom_fine_levels` > 0, a second chain restarts one
    coarse step before the coarse hit. It uses steps divided by
    `pom_fine_levels`.

    The resulting parameters feed a final parallax occlusion node, whose
    output 0 is the new UV.

    rawDisplacement is the 'F4' output of the FIRST coarse level node.
    NOTE(review): the meaning of 'F4' is defined by build_pbr_parallax_level;
    confirm.

    Raises:
        ValueError: if pom_levels < 1 (at least one coarse level is required).
    """
    # Fail early, before any node group is built; otherwise this surfaced as a
    # ZeroDivisionError / IndexError part-way through building the tree.
    if pom_levels < 1:
        raise ValueError(f"pom_levels must be >= 1, got {pom_levels}")

    # Create node group for one level, then instantiate and chain N levels.
    level_name = build_pbr_parallax_level(pbr_data, interpolation, proj)
    levels = _chain_pom_levels(nb, level_name, pom_levels)
    first_level = levels[0]

    # UV iterator, vector from real surface into virtual 'volume' per level.
    uv_iter = nb.add_vector_math_scale(
        incoming,
        nb.add_math_mul(
            parallax_depth,
            -0.01 / pom_levels  # Constant total depth, independent of the number of levels.
        )
    )

    # (max-min)/levels -> depth distance per iteration/level node.
    depth_step = nb.add_math_div(
        nb.add_math_sub(parallax_max, parallax_min),
        pom_levels
    )

    # Connect the first coarse parallax node.
    nb.wire_sockets(uv, first_level.inputs[0])
    nb.wire_sockets(uv_iter, first_level.inputs[1])
    nb.wire_sockets(parallax_max, first_level.inputs[2])
    nb.wire_sockets(depth_step, first_level.inputs[3])

    # Iterations until hit depth (after).
    coarse_cnt = levels[-1].outputs[4]

    if pom_fine_levels > 0:
        fine_levels = _chain_pom_levels(nb, level_name, pom_fine_levels)

        # Fine pass starts at the UV/Max(Top) one coarse step BEFORE the coarse
        # hit and iterates with fine uv/depth steps (/pom_fine_levels).
        coarse_cnt_m1 = nb.add_math_sub(coarse_cnt, 1)  # cnt before
        fine_uv = nb.add_vector_math_add(uv, nb.add_vector_math_scale(uv_iter, coarse_cnt_m1))
        fine_uv_iter = nb.add_vector_math_scale(uv_iter, 1.0 / pom_fine_levels)
        fine_max = nb.add_math_sub(parallax_max, nb.add_math_mul(depth_step, coarse_cnt_m1))
        fine_depth_step = nb.add_math_mul(depth_step, 1 / pom_fine_levels)

        # Connect the first fine level node.
        first_fine = fine_levels[0]
        nb.wire_sockets(fine_uv, first_fine.inputs[0])
        nb.wire_sockets(fine_uv_iter, first_fine.inputs[1])
        nb.wire_sockets(fine_max, first_fine.inputs[2])
        nb.wire_sockets(fine_depth_step, first_fine.inputs[3])

        # Occlusion node continues from the fine pass parameters.
        occ_uv = fine_uv
        occ_uv_iter = fine_uv_iter
        occ_max = fine_max
        occ_depth_steps = fine_depth_step
        occ_cnt = fine_levels[-1].outputs[4]  # Iterations until hit depth (after), fine.
    else:
        # Occlusion node continues from the coarse pass parameters.
        occ_uv = uv
        occ_uv_iter = uv_iter
        occ_max = parallax_max
        occ_depth_steps = depth_step
        occ_cnt = coarse_cnt

    # Add parallax occlusion node.
    occ = nb.add_group(build_pbr_parallax_occlusion(pbr_data, interpolation, proj))
    nb.wire_sockets(occ_uv, occ.inputs[0])
    nb.wire_sockets(occ_uv_iter, occ.inputs[1])
    nb.wire_sockets(occ_max, occ.inputs[2])
    nb.wire_sockets(occ_depth_steps, occ.inputs[3])
    nb.wire_sockets(occ_cnt, occ.inputs[4])

    # Return new uv and raw displacement info. Use the first coarse level
    # explicitly. The old code read a loop variable that the fine-level
    # chaining overwrote, so the source node depended on pom_fine_levels.
    return (occ.outputs[0], first_level.outputs['F4'])


def create_parallax_setup(
    nb: NodeBuilderCycles,
    pbr_data: PBREntry,
    depth_mode: str,
    pom_levels: int,
    pom_fine_levels: int,
    uv: NodeSocketVector,  # Incoming UV
    proj: str,             # Texture projection: FLAT | BOX
    parallax_min: NodeSocketFloat,
    parallax_max: NodeSocketFloat,
    parallax_depth: NodeSocketFloat,
    ) -> Tuple[NodeSocketVector, NodeSocketFloat]:
    """
    Complete parallax setup, returns (fixed UV, min/max visualization).

    depth_mode selects the technique:
      'PARALLAX' - basic single-sample parallax offset mapping.
      'POM'      - parallax occlusion mapping with `pom_levels` coarse and
                   `pom_fine_levels` refinement levels.

    The second return value is the output of a 50% mix between a color ramp
    over the displacement remapped to [parallax_min, parallax_max] and the
    raw displacement. The ramp is blue at the low end, gray in the middle and
    red at the high end. NOTE(review): this is presumably a color socket
    despite the NodeSocketFloat annotation; confirm.

    Raises:
        ValueError: if depth_mode is neither 'PARALLAX' nor 'POM'. This is
        checked before any node is added.
    """
    # Validate before creating nodes; an unknown mode would otherwise leave
    # par_depth unbound further down.
    if depth_mode not in ('PARALLAX', 'POM'):
        raise ValueError(
            f"Unsupported depth_mode {depth_mode!r}, expected 'PARALLAX' or 'POM'"
        )

    # Use this for height maps.
    interpolation = 'Smart'

    # FLAT projection: tangent-space incoming vector from the incoming tangent group.
    # Other projections: geometry 'Incoming' vector transformed world -> object space.
    if proj == 'FLAT':
        incoming = nb.add_group(create_incoming_tangent()).outputs[0]
    else:
        incoming = nb.add_vector_transform(
            'VECTOR',
            'WORLD',
            'OBJECT',
            nb.add_geometry().outputs['Incoming']
        ).outputs[0]

    # Different modes, different setups.
    if depth_mode == 'PARALLAX':
        uv, par_depth = __create_parallax_basic(
            nb,
            pbr_data,
            incoming,
            uv,
            proj,
            interpolation,
            parallax_min,
            parallax_max,
            parallax_depth
        )
    else:  # 'POM' (validated above)
        uv, par_depth = __create_parallax_occlusion(
            nb,
            pbr_data,
            pom_levels,
            pom_fine_levels,
            incoming,
            uv,
            proj,
            interpolation,
            parallax_min,
            parallax_max,
            parallax_depth
        )

    # Min/max visualization: a constant color ramp (blue at the low end, gray
    # in between, red at the high end) over the displacement remapped to
    # [min, max], mixed 0.5 with the raw displacement.
    pmm = nb.add_color_ramp(
        nb.add_map_range(par_depth, fmin=parallax_min, fmax=parallax_max).outputs[0],
        interpolation='CONSTANT',
        values=[
            (0.0, (0.0, 0.0, 1.0, 1.0)),  # Blue
            (0.02, (0.5, 0.5, 0.5, 1.0)), # Gray
            (0.98, (1.0, 0.0, 0.0, 1.0)), # Red
        ]
    )
    pmmf = nb.add_mix_color(operation='MIX', fac=0.5, color0=pmm.outputs[0], color1=par_depth).outputs[0]

    # NOTE: A parallax-derived normal (bump map from par_depth, blended via a
    # "Parallax Normal Mix" input) was tried and skipped. It did not look good,
    # and it would need a separate displacement image texture node sampled at
    # the final UV coords as bump source, which is not built here.
    return (uv, pmmf)