# Copyright (C) 2022 Thomas Hoppe (h0bB1T). All rights reserved.
#
# Unauthorized copying of this file via any medium is strictly prohibited.
# Proprietary and confidential.

"""
Used both in the Addon directly, but also from update_settings script externally.
"""

import bpy, os, hashlib, sys, traceback
from dataclasses import dataclass

from math import ceil
from mathutils import Vector

from typing import Any, Dict, List, Tuple, Union


# Used in Packer class
@dataclass
class PackerRectangle:
    """
    A rectangle handed to Packer for placement.
    """
    # Caller-defined payload; Packer itself never reads it.
    user: Any 
    # Size of the rectangle (width is added to a node's x, height to its y).
    width: float
    height: float
    # Assigned in Packer.__init__: the PackerNode the rectangle was placed in,
    # or None if no node could be found or grown for it.
    node: Union['PackerNode', None] = None


# Used in Packer class
@dataclass
class PackerNode:
    """
    Node of the binary tree Packer uses to keep track of free and occupied areas.
    """
    # Position of the node's origin corner; the root node is created at (0, 0).
    x: float 
    y: float
    # Size of the area covered by the node.
    width: float
    height: float
    # Set to True when the node is split (or created as a grown root);
    # Packer.find_node then only descends into 'right' and 'down'.
    used: bool = False
    # Children. Packer.split_node sets them to the free area below / to the
    # right of the placed rectangle; Packer.grow_right / grow_down hang the
    # previous root and the newly added strip in here.
    down: Union['PackerNode', None] = None
    right: Union['PackerNode', None] = None


def log_exception_shared(exc: Exception, *, context_msg: str = ""):
    """
    Report an exception: full traceback to stderr, one-line summary to stdout.

    The traceback printed is the one attached to 'exc' itself, so the result is
    correct even when this is called outside of the 'except' block that caught it.

    exc: The exception to report.
    context_msg: Optional text placed between the '[AWP]' prefix and 'Error:'.
    """
    log_prefix = "[AWP]"
    if context_msg:
        msg = f"{log_prefix} {context_msg} Error: {exc}"
    else:
        msg = f"{log_prefix} Error: {exc}"

    # print_exc() would print whatever exception is *currently being handled*
    # (or 'NoneType: None' if there is none) instead of the one passed in.
    # The three-argument form is accepted by all Python 3 versions.
    traceback.print_exception(type(exc), exc, exc.__traceback__, file=sys.stderr)
    print(msg)


class Packer:
    """
    Evenly distribute source_rects in an automatically growing squarish rect, used for 
    automatically placing assets in library scene.

    The packed area is tracked as a binary tree of PackerNode objects. The
    constructor does all the work: afterwards every rectangle's 'node' attribute
    holds the node it was placed in (its x / y are the rectangle's position) and
    'root' covers the complete packed area.
    """
    def __init__(self, source_rects: List[PackerRectangle]):
        # Tallest rectangles first, equally tall ones narrowest first. Both
        # sorts are stable (also with reverse=True), so rectangles of identical
        # size keep their input order.
        by_width = sorted(source_rects, key=lambda r: r.width)
        rects = sorted(by_width, key=lambda r: r.height, reverse=True)

        # The initial area is exactly as large as the first (tallest) rectangle.
        self.w = rects[0].width if rects else 0
        self.h = rects[0].height if rects else 0
        self.root = PackerNode(0, 0, self.w, self.h)

        for r in rects:
            n = self.find_node(self.root, r.width, r.height)
            if n:
                r.node = self.split_node(n, r.width, r.height)
            else:
                # No free node is large enough, enlarge the packed area.
                r.node = self.grow_node(r.width, r.height)


    def find_node(self, node: PackerNode, width: float, height: float) -> Union[PackerNode, None]:
        """
        Return the first unused node below 'node' that can hold width x height, or None.

        Used nodes are searched 'right' subtree first, then 'down' subtree.
        Implemented with an explicit stack instead of recursion: every grow adds
        a tree level, so a recursive search could exceed the interpreter's
        recursion limit on degenerate layouts.
        """
        pending = [node]
        while pending:
            current = pending.pop()
            if current.used:
                # Push 'down' first so that 'right' (and everything below it)
                # is popped and examined before it.
                pending.append(current.down)
                pending.append(current.right)
            elif width <= current.width and height <= current.height:
                return current
        return None


    def split_node(self, node: PackerNode, width: float, height: float) -> PackerNode:
        """
        Occupy 'node' with a width x height rectangle and return it.

        The remaining area becomes two free children: 'down' spans the full
        node width below the rectangle, 'right' is the strip beside it.
        """
        node.used = True
        node.down = PackerNode(node.x, node.y + height, node.width, node.height - height)
        node.right = PackerNode(node.x + width, node.y, node.width - width, height)
        return node


    def grow_node(self, width: float, height: float) -> Union[PackerNode, None]:
        """
        Enlarge the packed area so width x height fits and place it there.

        Returns the node the rectangle was placed in, or None if the area can
        be grown in neither direction.
        """
        can_grow_down = width <= self.root.width
        can_grow_right = height <= self.root.height
        # Prefer the direction that keeps the overall area squarish.
        should_grow_down = can_grow_down and self.root.width >= self.root.height + height
        should_grow_right = can_grow_right and self.root.height >= self.root.width + width

        if should_grow_right: return self.grow_right(width, height)
        if should_grow_down: return self.grow_down(width, height)
        if can_grow_right: return self.grow_right(width, height)
        if can_grow_down: return self.grow_down(width, height)
        return None


    def grow_right(self, width: float, height: float) -> Union[PackerNode, None]:
        """
        Add a strip of the given width to the right of the current root, then
        place width x height in the enlarged tree.
        """
        previous = self.root
        self.root = PackerNode(
            0, 
            0, 
            previous.width + width, 
            previous.height, 
            used=True,
            down=previous, 
            right=PackerNode(previous.width, 0, width, previous.height),
        )
        n = self.find_node(self.root, width, height)
        if n:
            return self.split_node(n, width, height)
        return None


    def grow_down(self, width: float, height: float) -> Union[PackerNode, None]:
        """
        Add a strip of the given height below the current root, then place
        width x height in the enlarged tree.
        """
        previous = self.root
        self.root = PackerNode(
            0, 
            0, 
            previous.width, 
            previous.height + height, 
            used=True,
            down=PackerNode(0, previous.height, previous.width, height),
            right=previous, 
        )
        n = self.find_node(self.root, width, height)
        if n:
            return self.split_node(n, width, height)
        return None


def create_image_hash(image: bpy.types.Image, packed_only: bool) -> Union[str, None]:
    """
    Create hash from image source (file on disk or packed data)

    image: The image to hash.
    packed_only: If True, only packed data is considered and the file on disk is ignored.

    Returns the hex digest (BLAKE2s) of the image bytes, or None if the image
    has neither packed data nor (unless packed_only) a readable file path.
    """
    # Packed data takes precedence over the file on disk.
    packed = image.packed_file
    if packed and packed.data:
        return hashlib.blake2s(packed.data).hexdigest()

    # Otherwise fall back to the external file, if allowed.
    if not packed_only and image.filepath:
        # Passing the library makes '//'-relative paths of linked images resolve
        # against the library file instead of the currently opened file.
        path = os.path.realpath(bpy.path.abspath(image.filepath, library=image.library))
        # isfile() rather than exists(): a directory exists but cannot be opened for reading.
        if os.path.isfile(path):
            digest = hashlib.blake2s()
            with open(path, 'rb') as f:
                # Hash in chunks, so large textures are never held in memory completely.
                for chunk in iter(lambda: f.read(1024 * 1024), b''):
                    digest.update(chunk)
            return digest.hexdigest()

    return None


def remove_duplicate_images(packed_only: bool):
    """
    Merge images with identical binary content.

    All images in bpy.data.images are grouped by the hash of their content
    (see create_image_hash). Within each group of two or more, one image is
    kept and the users of all others are remapped to it.

    packed_only: Passed on to create_image_hash; if True only packed data is compared.
    """
    # Group images by content hash; images that yield no hash are left alone.
    by_digest = {} # type: Dict[str, List[bpy.types.Image]]
    for image in bpy.data.images:
        digest = create_image_hash(image, packed_only)
        if digest:
            by_digest.setdefault(digest, []).append(image)

    for group in by_digest.values():
        if len(group) < 2:
            continue # Unique image, nothing to merge.

        # Keep the first packed image of the group, or simply the first one
        # if none of them is packed.
        keeper = next(
            (img for img in group if img.packed_file and img.packed_file.data), 
            group[0]
        )

        # Redirect every user of the duplicates to the kept image.
        for img in group:
            if img != keeper:
                img.user_remap(keeper.id_data)


def bounding_box(objects: List[bpy.types.Object]) -> Tuple[Tuple[float, float, float], Tuple[float, float, float]]:
    """
    Find bounding box enclosing the given objects, return dimension and center.

    The box is axis aligned and built from the corners of every object's
    bound_box, transformed by the object's matrix_world.

    Returns ((dx, dy, dz), (cx, cy, cz)). If 'objects' is empty, or no object
    yields any corner, both tuples are (0, 0, 0).
    """
    empty = ((0, 0, 0), (0, 0, 0))
    if not objects:
        return empty

    corners = []
    for o in objects:
        try:
            # Build the list first, so a failing object contributes no partial data.
            corners.extend([ o.matrix_world @ Vector(corner) for corner in o.bound_box ])
        except Exception as e:
            log_exception_shared(e, context_msg=f'Failed to get bounding box for object: {o.name}')

    # Every object failed: return the neutral box rather than None, as callers
    # unpack the result unconditionally.
    if not corners:
        return empty

    lows = [ min(v[axis] for v in corners) for axis in range(3) ]
    highs = [ max(v[axis] for v in corners) for axis in range(3) ]

    dimension = tuple(hi - lo for lo, hi in zip(lows, highs))
    center = tuple(size / 2 + lo for lo, size in zip(lows, dimension))

    return (dimension, center)


def do_place(assets: List[Union[bpy.types.Collection, bpy.types.Object]], padding: float = 0.2):
    """
    Arrange the given assets next to each other, centered around the origin in x / y.

    Every asset gets a rectangle of its bounding box footprint (x / y), enlarged
    by 'padding' on each side. The rectangles are arranged by Packer, then the
    assets are moved so their bounding box is centered on their rectangle and
    its bottom sits at z = 0.

    Objects that have a parent are not moved. For collections only the objects
    without parent are shifted, all by the same amount.

    Assets that fail to be measured or placed are logged and skipped.

    assets: Objects and collections to place.
    padding: Free space added on each side of every asset.
    """
    if not assets: return

    # Measure: one rectangle per asset.
    rects = [] # type: List[PackerRectangle]
    for a in assets:
        try:
            # 'anchor' is the location the offsets below are relative to.
            if isinstance(a, bpy.types.Object):
                dimension, center = bounding_box([a])
                anchor = a.location
            elif isinstance(a, bpy.types.Collection):
                dimension, center = bounding_box(a.all_objects)
                anchor = a.all_objects[0].location
            else:
                # Without this, the values measured for the previous asset
                # would silently be reused for this one.
                raise TypeError(f'Unsupported asset type: {type(a).__name__}')

            # Center is not automatically location (origin), take this into account.
            # In z the reference is the bottom of the bounding box, not its center.
            offset_x = anchor[0] - center[0]
            offset_y = anchor[1] - center[1]
            offset_z = anchor[2] - (center[2] - dimension[2] / 2)

            rects.append(
                PackerRectangle(
                    user = (a, offset_x, offset_y, offset_z), 
                    width = dimension[0] + 2 * padding, 
                    height = dimension[1] + 2 * padding
                )
            )
        except Exception as e:
            log_exception_shared(e, context_msg=f'Failed to process asset: {getattr(a, "name", a)}')

    # Determine positions.
    packer = Packer(rects)
    # Coordinates are 0+, center both axis, so center is 0, 0
    center_offset_x = packer.root.width / 2
    center_offset_y = packer.root.height / 2

    # Place.
    for r in rects:
        a, offset_x, offset_y, offset_z = r.user
        try:
            # Find center on rectangle.
            x = r.node.x + r.width / 2
            y = r.node.y + r.height / 2

            if isinstance(a, bpy.types.Object) and not a.parent:
                a.location = (x + offset_x - center_offset_x, y + offset_y - center_offset_y, offset_z)
            elif isinstance(a, bpy.types.Collection):
                # Collections are somewhat special,
                # objects must stay relative to each other.
                objs = a.all_objects
                if objs:
                    # Translation that brings the first object to its target spot.
                    rel = (
                        x - objs[0].location[0] + offset_x - center_offset_x,
                        y - objs[0].location[1] + offset_y - center_offset_y,
                        0 - objs[0].location[2] + offset_z
                    )
                    for o in objs:
                        if not o.parent:
                            o.location = (
                                o.location[0] + rel[0],
                                o.location[1] + rel[1],
                                o.location[2] + rel[2],
                            )
        except Exception as e:
            log_exception_shared(e, context_msg=f'Failed to place asset: {a.name}')


def auto_place(padding: float = 0.2, update_collection_insert_point: bool = False):
    """
    Places all asset-related object in a grid.

    Asset-related are: collections marked as asset, objects marked as asset and
    objects that have a material marked as asset in one of their material slots.
    Every datablock is placed once, even if it qualifies for several reasons.

    padding: Free space around each asset, passed on to do_place.
    update_collection_insert_point: If True, the instance_offset of every asset
        collection is afterwards set from the collection's bounding box (its
        center, moved down by half of its height).
    """
    assets = []
    seen = set()

    def add(datablock):
        # Keeps first-seen order; a duplicate would reserve a second
        # rectangle and leave a hole in the layout.
        if datablock not in seen:
            seen.add(datablock)
            assets.append(datablock)

    # Collections.
    for c in bpy.data.collections:
        if c.asset_data:
            add(c)

    # Objects.
    for o in bpy.data.objects:
        if o.asset_data:
            add(o)

    # Objects that contain materials with asset mark.
    for m in bpy.data.materials:
        if not m.asset_data:
            continue
        for o in bpy.data.objects:
            if any(ms.material == m for ms in o.material_slots):
                add(o)

    do_place(assets, padding)

    if update_collection_insert_point:
        # Objects have just been moved, bring world matrices up to date
        # before measuring again.
        bpy.context.view_layer.update()
        for c in bpy.data.collections:
            if c.asset_data:
                bbox = bounding_box(c.all_objects)
                if not bbox:
                    # No usable bounding box, leave the insert point untouched.
                    continue
                dimension, center = bbox
                # Vector * Vector multiplies per component, so only z is
                # shifted: by half of the height, downwards.
                c.instance_offset = Vector(center) + Vector(dimension) * Vector((0, 0, -0.5))
