# Copyright (C) 2022 Thomas Hoppe (h0bB1T). All rights reserved.
#
# Unauthorized copying of this file via any medium is strictly prohibited.
# Proprietary and confidential.

import bpy

from bpy.types import Operator
from bpy.props import StringProperty, IntProperty, FloatProperty, BoolProperty

import numpy as np

from ..utils.dev import err,log_exception
from ..utils.node_builder_cycles import NodeBuilderCycles


class UI_OT_support_mask_bake(Operator):
    """
    Bakes the AO/Edge (AOE) mask of the current object into an image texture.
    """
    bl_idname = 'snw.support_mask_bake'
    bl_label = 'Bake AO/Edge Mask'
    bl_description = 'Bake AO/Edge Mask of current object to a texture'
    bl_options = {'REGISTER'}


    image: StringProperty() # type: ignore
    resolution: IntProperty() # type: ignore
    uv_set: StringProperty() # type: ignore

    ao_samples: IntProperty() # type: ignore
    ao_distance: FloatProperty() # type: ignore
    bevel_samples: IntProperty() # type: ignore
    bevel_distance: FloatProperty() # type: ignore

    cycles_samples: IntProperty() # type: ignore
    cycles_denoise: BoolProperty() # type: ignore

    bake_margin: IntProperty() # type: ignore

    post_blur: BoolProperty() # type: ignore
    disable_modifiers: BoolProperty() # type: ignore


    def _create_image(self):
        """
        Returns the image named `self.image` that receives the bake.

        An image of that name that already exists is reused; if its size is
        not `self.resolution` x `self.resolution`, its generated width and
        height are set to the requested resolution. If no such image exists,
        a new square image without alpha is created.
        """
        existing = bpy.data.images.get(self.image)
        if existing is None:
            return bpy.data.images.new(self.image, self.resolution, self.resolution, alpha=False)

        current_size = (existing.size[0], existing.size[1])
        if current_size != (self.resolution, self.resolution):
            # NOTE: packed images apparently cannot be resized this way;
            # unpacking them first (unpack('REMOVE')) was tried and left
            # disabled.
            existing.generated_width = self.resolution
            existing.generated_height = self.resolution
        return existing


    def _create_material(self, image: bpy.types.Image):
        """
        Builds the temporary bake material and returns it.

        The material emits an RGB color composed of:
          R: 1 - ambient occlusion (local only), clamped
          G: 1 - ambient occlusion with `inside` enabled (local only), clamped
          B: length of cross(geometry normal, bevel normal), clamped

        An image texture node for `image` (driven by UV map `self.uv_set`) is
        added and left as the only selected node.
        """
        material_name = '__snm_bake_material'

        # Start from a fresh material; drop a leftover one from a former run.
        stale = bpy.data.materials.get(material_name)
        if stale is not None:
            bpy.data.materials.remove(stale)
        m = bpy.data.materials.new(material_name)
        m.use_nodes = True

        nb = NodeBuilderCycles()
        nb.tree_from_material(m)
        nb.clear_nodes()

        # Red channel: inverted AO.
        ao_outside = nb.add_ambient_occlusion(
            samples=self.ao_samples,
            distance=self.ao_distance,
            only_local=True
        )
        red = nb.add_math_sub(val0=1.0, val1=ao_outside.outputs[0], clamp=True)

        # Green channel: inverted AO with the `inside` option.
        ao_inside = nb.add_ambient_occlusion(
            samples=self.ao_samples,
            distance=self.ao_distance,
            inside=True,
            only_local=True
        )
        green = nb.add_math_sub(val0=1.0, val1=ao_inside.outputs[0], clamp=True)

        # Blue channel: |normal x bevel|; adding 0 with clamp only clamps.
        geometry_normal = nb.add_geometry().outputs['Normal']
        bevel_normal = nb.add_bevel(
            samples=self.bevel_samples,
            radius=self.bevel_distance
        ).outputs[0]
        cross = nb.add_vector_math_cross(val0=geometry_normal, val1=bevel_normal)
        cross_length = nb.add_vector_math_length(val=cross)
        blue = nb.add_math_add(val0=cross_length, val1=0, clamp=True)

        # Combine the channels and route them through an emission shader.
        combined = nb.add_combine_rgb(r=red, g=green, b=blue)
        emission = nb.add_emission(color=combined.outputs[0])
        nb.add_output_material(surface=emission.outputs[0])

        # Image texture node holding the bake target.
        uv_map = nb.add_uv_map(self.uv_set)
        image_node = nb.add_image(
            image=image,
            colorspace='Non-Color',
            vector=uv_map.outputs[0]
        )

        # Leave the image node as the only selected node
        # (presumably so the bake writes into it -- TODO confirm).
        for node in m.node_tree.nodes:
            node.select = False
        image_node.select = True

        nb.arrange()

        return m


    def _delete_material(self):
        """
        Removes the temporary bake material from the blend data, if present.
        """
        leftover = bpy.data.materials.get('__snm_bake_material')
        if leftover is not None:
            bpy.data.materials.remove(leftover)


    def _bake(self, context, bake_material: bpy.types.Material):
        """
        Runs the whole bake process for the active object.

        Temporarily switches the scene to Cycles with the operator's sample
        and denoise settings, then bakes the emission of `bake_material`.
        The previous render settings are restored in a `finally` clause, so
        they come back even if a step fails.

        Errors raised while switching the render settings propagate to the
        caller; errors of the later stages are logged via `log_exception`
        and not re-raised.
        """
        scene = context.scene # type: bpy.types.Scene
        r = scene.render
        c = scene.cycles

        # Remember every setting that gets overridden for the bake.
        renderer = r.engine
        samples = c.samples
        preview_samples = c.preview_samples
        use_denoising = c.use_denoising

        try:
            r.engine = 'CYCLES'
            # Per the Blender manual, Cycles bakes with the render settings,
            # so the render sample count has to be set. The preview sample
            # count is still set as well, as before.
            c.samples = self.cycles_samples
            c.preview_samples = self.cycles_samples
            c.use_denoising = self.cycles_denoise

            try:
                self._bake_active_object(context, bake_material)
            except Exception as ex:
                log_exception(ex, context_msg='Bake object handling failed')
        finally:
            # Restore bake settings.
            r.engine = renderer
            c.samples = samples
            c.preview_samples = preview_samples
            c.use_denoising = use_denoising

    def _bake_active_object(self, context, bake_material: bpy.types.Material):
        """
        Makes the active object the only selected one, bakes it and restores
        the previous selection afterwards.
        """
        bpy.ops.object.mode_set(mode='OBJECT')
        obj = context.active_object # type: bpy.types.Object
        previously_selected = list(context.selected_objects)
        # Raises before anything is changed if there is no active object.
        obj_was_selected = obj.select_get()

        try:
            for other in previously_selected:
                other.select_set(False)
            obj.select_set(True)

            try:
                self._bake_with_material(obj, bake_material)
            except Exception as ex:
                log_exception(ex, context_msg='Bake material handling failed')
        finally:
            # Restore selection, including the state of the baked object.
            obj.select_set(obj_was_selected)
            for other in previously_selected:
                other.select_set(True)

    def _bake_with_material(self, obj, bake_material: bpy.types.Material):
        """
        Assigns `bake_material` to all material slots of `obj` (or appends it
        if there are none), bakes and puts the original materials back.
        """
        original = [slot.material for slot in obj.material_slots]
        appended = False

        try:
            if original:
                for slot in obj.material_slots:
                    slot.material = bake_material
            else:
                obj.data.materials.append(bake_material)
                appended = True

            try:
                self._bake_with_modifiers(obj)
            except Exception as ex:
                log_exception(ex, context_msg='Bake modifier handling failed')
        finally:
            # Restore materials.
            if original:
                for slot, material in zip(obj.material_slots, original):
                    slot.material = material
            elif appended:
                obj.data.materials.pop(index=0)

    def _bake_with_modifiers(self, obj):
        """
        Runs the EMIT bake; if `self.disable_modifiers` is set, modifiers that
        are enabled for rendering are switched off for the duration of the
        bake and switched on again afterwards.
        """
        disabled = []

        try:
            if self.disable_modifiers:
                for modifier in obj.modifiers:
                    if modifier.show_render:
                        modifier.show_render = False
                        disabled.append(modifier)

            try:
                bpy.ops.object.bake(
                    type='EMIT',
                    margin=self.bake_margin,
                    use_clear=True
                )
            except Exception as ex:
                log_exception(ex, context_msg='Bake failed')
        finally:
            # Restore modifier state.
            for modifier in disabled:
                modifier.show_render = True


    def _blur_channel(self, a, w, h):
        """
        Blur single (RGB) channel.
        """
        # 1D to 2D.
        a = a.reshape(h, w) 

        # Create the kernel.
        kernel = np.array([ 1.0, 2.0, 1.0 ])
        kernel = kernel / np.sum(kernel)

        # Apply in both axis.
        a = np.apply_along_axis(lambda x: np.convolve(x, kernel, mode='same'), 0, a)
        a = np.apply_along_axis(lambda x: np.convolve(x, kernel, mode='same'), 1, a)

        # Return 2D to 1D version.
        return a.reshape(w * h)        


    def _post_blur(self, image: bpy.types.Image):
        """
        Blurs the image in place, channel by channel.

        The R, G and B channels are each passed through `_blur_channel`; the
        alpha channel is overwritten with 1. Images without pixels (width or
        height of 0) are left untouched.

        NOTE(review): like before, this assumes 4 values (RGBA) per pixel --
        confirm for images that report a different `channels` count.
        """
        w, h = image.size
        if w == 0 or h == 0:
            return

        # Read all pixels once into a single buffer; one row per pixel.
        rgba = np.empty(w * h * 4, dtype=np.float32)
        image.pixels.foreach_get(rgba)
        rgba = rgba.reshape(-1, 4)

        # Blur each color channel.
        for channel in range(3):
            rgba[:, channel] = self._blur_channel(rgba[:, channel], w, h)

        # Alpha is not blurred but set to 1.
        rgba[:, 3] = 1.0

        # Write the result back to the image.
        image.pixels.foreach_set(rgba.ravel())


    def execute(self, context):
        """
        Operator entry point: creates (or reuses) the target image, bakes the
        mask into it using a temporary material and optionally blurs the
        result.

        The temporary material is removed in a `finally` clause, so it does
        not stay in the blend data if material creation or the bake raises.

        Returns {'FINISHED'}.
        """
        image = self._create_image()
        try:
            bake_material = self._create_material(image)
            self._bake(context, bake_material)
        finally:
            self._delete_material()

        if self.post_blur:
            self._post_blur(image)

        return {'FINISHED'}
