# Copyright (C) 2022 Thomas Hoppe (h0bB1T). All rights reserved.
#
# Unauthorized copying of this file via any medium is strictly prohibited.
# Proprietary and confidential.

from ...operators.helpers.pbr_map_calculator import PbrMap, gamma
from ...utils.node_builder_cycles import NodeBuilderCycles
from ...snw.node_info import encode_pbr_map_info


def create_pbr_map(name: str) -> NodeBuilderCycles:
    """
    Start a new PBR Map node tree called `name` and declare its sockets.

    Only the tree and its input/output sockets are created here; no
    internal nodes are added. The builder is returned so the caller can
    fill in the contents afterwards.
    """
    # Socket declarations are handed to the builder unchanged.
    # NOTE(review): the three trailing numbers look like default / min / max
    # -- confirm against NodeBuilderCycles.add_inputs.
    input_sockets = [
        ('FloatFactor', 'Pattern', 0.5, 0.0, 1.0),
        ('FloatFactor', 'Hue', 0.5, 0.0, 1.0),
        ('FloatFactor', 'Saturation', 1.0, 0.0, 2.0),
        ('FloatFactor', 'Value', 1.0, 0.0, 2.0),
        ('FloatFactor', 'Brightness', 0.0, -0.5, 0.5),
        ('FloatFactor', 'Contrast', 0.0, -0.5, 0.5),
        ('FloatFactor', 'Roughness Offset', 0.0, -1.0, 1.0),
        ('Vector', 'Normal', (0.0, 0.0, 0.0)),
    ]
    output_sockets = [
        ('Shader', 'Shader'),
    ]

    builder = NodeBuilderCycles()
    builder.build_node_tree(name)
    # 'Normal' is the only input passed in the builder's `hidden` set.
    builder.add_inputs(input_sockets, hidden={'Normal'})
    builder.add_outputs(output_sockets)

    return builder


def setup_pbr_map(nb: NodeBuilderCycles, pbr_map: PbrMap):
    """
    Fill the node tree held by `nb` from the entries of `pbr_map`.

    Steps, in order:
      * the encoded map is stored on the tree (used for editing),
      * two color ramps driven by the 'Pattern' input get one stop per
        entry: one holds r/g/b, the other packs metallic/specular/roughness
        into its color channels,
      * the color ramp goes through HSV and brightness/contrast nodes,
      * everything feeds a Principled BSDF wired to the 'Shader' output,
      * the builder arranges the resulting nodes.
    """
    def single_stop_ramp():
        # The last stop of a ramp can't be removed, so strip the fresh ramp
        # down to one stop; that stop is reused for the first entry.
        node = nb.add_color_ramp(factor=nb.input('Pattern'))
        ramp = node.color_ramp
        while len(ramp.elements) > 1:
            ramp.elements.remove(ramp.elements[1])
        return node, ramp

    def channel(byte_value):
        # Blender automatically adjusts gamma and this is meant to prevent
        # that. NOTE(review): the original author was not sure this is
        # correct -- confirm.
        return gamma(byte_value / 255.0, 2.2)

    # Store for edit.
    nb.tree.snw_node_info = encode_pbr_map_info(pbr_map.to_dict())

    base_node, base_ramp = single_stop_ramp()
    data_node, data_ramp = single_stop_ramp()
    data_channels = nb.add_separate_rgb(data_node.outputs['Color'])

    shares = pbr_map.distribution_map()
    position = 0.0
    for index, entry in enumerate(pbr_map.entries):
        if index == 0:
            # Reuse the stop that survived the stripping above.
            base_stop = base_ramp.elements[0]
            data_stop = data_ramp.elements[0]
        else:
            base_stop = base_ramp.elements.new(position)
            data_stop = data_ramp.elements.new(position)

        base_stop.color = (
            channel(entry.r),
            channel(entry.g),
            channel(entry.b),
            1,
        )
        data_stop.color = (
            channel(entry.metallic),
            channel(entry.specular),
            channel(entry.roughness),
            1,
        )

        # The next stop starts where this entry's share ends.
        position += shares[index]

    tinted = nb.add_hsv(
        hue=nb.input('Hue'),
        saturation=nb.input('Saturation'),
        value=nb.input('Value'),
        color=base_node.outputs['Color'],
    ).outputs[0]

    adjusted = nb.add_brightness_contrast(
        brightness=nb.input('Brightness'),
        contrast=nb.input('Contrast'),
        color=tinted,
    ).outputs[0]

    # Roughness from the ramp plus the user offset; the add is clamped.
    roughness = nb.add_math_add(
        val0=data_channels.outputs[2],
        val1=nb.input('Roughness Offset'),
        clamp=True,
    )

    bsdf = nb.add_principled_bsdf(
        basecolor=adjusted,
        metallic=data_channels.outputs[0],
        specular=data_channels.outputs[1],
        roughness=roughness,
        normal=nb.input('Normal'),
    )

    nb.wire_sockets(bsdf.outputs[0], nb.output('Shader'))

    nb.arrange()
