# Copyright (C) 2022 Thomas Hoppe (h0bB1T). All rights reserved.
#
# Unauthorized copying of this file via any medium is strictly prohibited.
# Proprietary and confidential.

import colorsys

from typing import Any, Tuple, Dict, List, Union


def gamma(v: float, g: float):
    """
    Apply a gamma exponent to a value: returns v raised to the power g.

    Intended for normalized values v in the range 0..1.
    """
    corrected = pow(v, g)
    return corrected


def hsv(r: int, g: int, b: int) -> Tuple[float, float, float]:
    """
    RGB (0..255) to HSV (hue 0..179, saturation/value 0..255).

    The components are returned as unrounded floats.
    """
    h, s, v = colorsys.rgb_to_hsv(r / 255, g / 255, b / 255)
    # colorsys yields hue in [0, 1), so the scaled hue lies in [0, 179).
    return (h * 179, s * 255, v * 255)


class PbrMapEntry:
    """
    Stores processed, quantized data of a single entry.

    The constructor reads the same short keys that to_dict() writes
    ('p', 'r', 'g', 'b', 'me', 'sp', 'ro', 'hu', 'sa', 'va'), so an entry
    survives a to_dict() -> PbrMapEntry(...) round trip. Missing keys
    (or no dict at all) default to 0.
    """
    def __init__(self, init: Union[Dict[str, float], None] = None):
        data = init if init else {}

        # Number of source pixels this entry represents.
        self.pixels = data.get('p', 0)

        # Color channels.
        self.r = data.get('r', 0)
        self.g = data.get('g', 0)
        self.b = data.get('b', 0)

        # PBR channels.
        self.metallic = data.get('me', 0)
        self.specular = data.get('sp', 0)
        self.roughness = data.get('ro', 0)

        # HSV values. to_dict() writes 'hu'/'sa'/'va'; the long names
        # 'hue'/'saturation'/'value' (which this constructor used to read)
        # are still accepted as a fallback for backward compatibility.
        self.hue = data.get('hu', data.get('hue', 0))
        self.saturation = data.get('sa', data.get('saturation', 0))
        self.value = data.get('va', data.get('value', 0))


    def to_dict(self) -> Dict[str, float]:
        """
        Serialize to a compact dict that PbrMapEntry(init=...) can read back.
        """
        return {
            'p': self.pixels,
            'r': self.r,
            'g': self.g,
            'b': self.b,
            'me': self.metallic,
            'sp': self.specular,
            'ro': self.roughness,
            'hu': self.hue,
            'sa': self.saturation,
            'va': self.value
        }


class PbrMap:
    """
    Ordered list of map entries plus the sort order last applied to them.
    """

    # Sort key per order code accepted by reorder().
    _SORT_KEYS = {
        'H': lambda e: e.hue,
        'S': lambda e: e.saturation,
        'V': lambda e: e.value,
        'R': lambda e: e.r,
        'G': lambda e: e.g,
        'B': lambda e: e.b,
        'RGB': lambda e: e.r + e.g + e.b,
        'Me': lambda e: e.metallic,
        'Sp': lambda e: e.specular,
        'Ro': lambda e: e.roughness,
    }

    def __init__(self, entries: Union[List['PbrMapEntry'], None] = None):
        self.entries = entries if entries else []
        # Order code last passed to reorder(); '' until reorder() is called.
        self.order = ''


    def reorder(self, order: str):
        """
        Adjust order of entries to given value.

        Sorts entries ascending by the attribute selected by `order`:
        'H', 'S', 'V' (hue, saturation, value), 'R', 'G', 'B', 'RGB'
        (sum of the three channels), 'Me', 'Sp', 'Ro' (metallic, specular,
        roughness). An unknown code is still stored in self.order but leaves
        the entries unchanged.
        """
        self.order = order
        key = self._SORT_KEYS.get(order)
        if key is not None:
            self.entries = sorted(self.entries, key=key)


    def distribution_map(self, power: float = 1.0) -> List[float]:
        """
        Create map, where each entry gets an occurrence factor based on pixel amount.

        The factors are returned in entry order and sum to 1. `power` is
        applied to each factor before renormalizing: values below 1 (e.g. 0.1)
        give less occurring colors more space; the default 1.0 keeps the plain
        pixel ratio. If no entry has any pixels, every entry gets the same
        factor instead of dividing by zero. No entries yield an empty list.

        Raises ValueError if power is not > 0.
        """
        if power <= 0:
            raise ValueError('power must be > 0')
        if not self.entries:
            return []

        # Use number of pixels in each cluster for amount of space in color ramp.
        total = sum(e.pixels for e in self.entries)
        if total == 0:
            return [1 / len(self.entries)] * len(self.entries)
        percentages = [e.pixels / total for e in self.entries]

        # Optionally flatten the distribution; skipped for 1.0 so the default
        # result is exactly the plain ratio (no renormalization rounding).
        if power != 1.0:
            percentages = [p ** power for p in percentages]
            norm = sum(percentages)
            percentages = [p / norm for p in percentages]
        return percentages


    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize entries (via their own to_dict()) and the current order.
        """
        return {
            'entries': [e.to_dict() for e in self.entries],
            'order': self.order
        }


class PbrCluster:
    """
    Used to collect similar pixels.

    Each pixel is an (r, g, b, metallic, specular, roughness) tuple.
    update_mean() derives the rounded per-channel means and the HSV of the
    mean color; to_map_entry() exports them.
    """
    def __init__(self, mask: int):
        # Per-channel bit mask this cluster was hashed with (octree level).
        self.mask = mask

        # Pixels in this cluster.
        self.pixels = [] # type: List[Tuple[int, int, int, int, int, int]]

        # Mean values (updated in update_mean()).
        self.r_m = 0
        self.g_m = 0
        self.b_m = 0
        self.metallic_m = 0
        self.specular_m = 0
        self.roughness_m = 0

        # HSV of the mean color (updated in update_mean()).
        self.hue_m = 0
        self.saturation_m = 0
        self.value_m = 0


    def update_mean(self):
        """
        Update all mean values (rounded to int) and the HSV of the mean color.

        An empty cluster has no mean: all values are reset to 0 instead of
        raising ZeroDivisionError.
        """
        count = len(self.pixels)
        if count == 0:
            self.r_m = self.g_m = self.b_m = 0
            self.metallic_m = self.specular_m = self.roughness_m = 0
            self.hue_m = self.saturation_m = self.value_m = 0
            return

        # Transpose pixel tuples into channels and sum each channel.
        sums = [sum(channel) for channel in zip(*self.pixels)]
        (self.r_m, self.g_m, self.b_m,
         self.metallic_m, self.specular_m, self.roughness_m) = [round(s / count) for s in sums]

        # HSV is derived from the rounded mean color, not averaged per pixel.
        self.hue_m, self.saturation_m, self.value_m = hsv(self.r_m, self.g_m, self.b_m)


    def to_map_entry(self) -> 'PbrMapEntry':
        """
        Export pixel count, mean values and HSV as a PbrMapEntry.

        The mean fields are only refreshed by update_mean(); call it first.
        """
        entry = PbrMapEntry()
        entry.pixels = len(self.pixels)
        entry.r = self.r_m
        entry.g = self.g_m
        entry.b = self.b_m
        entry.metallic = self.metallic_m
        entry.specular = self.specular_m
        entry.roughness = self.roughness_m
        entry.hue = self.hue_m
        entry.saturation = self.saturation_m
        entry.value = self.value_m
        return entry


class PbrQuantizer:
    """
    Collect similar pixels (by R, G, B, Metallic & Roughness) in clusters to build
    material mappings.

    Pixels are hashed into clusters by the top bits of those five channels
    (initial mask 0xe0). The largest clusters are then split by re-hashing
    their pixels with one more bit per channel, until no splittable cluster
    exceeds the size threshold. Specular is averaged but not used for hashing.
    """
    # Mask keeping all 8 bits of a channel; clusters at this level cannot be
    # split any further (re-hashing would recreate the same cluster).
    _FULL_MASK = 0xff

    def __init__(self):
        # Clusters keyed by the packed, masked channel hash.
        self.matrix = {} # type: Dict[int, PbrCluster]
        # Intended to receive the reduced cluster set; no method of this
        # class currently writes it.
        self.result = [] # type: List[PbrCluster]


    def __add_pixel(self, r: int, g: int, b: int, metallic: int, specular: int, roughness: int, mask: int = 0xe0):
        """
        Add single source pixel to the cluster matching its masked channels.
        Most time consuming function.
        """
        # Pack the masked channels into one int key, one byte per channel.
        # Channels are assumed to be 8-bit (0..255); since every mask is at
        # most 0xff, higher bits are dropped rather than overlapping fields.
        key = ((metallic & mask) << 32) | ((roughness & mask) << 24) | ((r & mask) << 16) | ((g & mask) << 8) | (b & mask)

        cluster = self.matrix.get(key)
        if cluster is None:
            cluster = PbrCluster(mask)
            self.matrix[key] = cluster
        cluster.pixels.append((r, g, b, metallic, specular, roughness))


    def __refine(self, threshold: int) -> bool:
        """
        Subdivide the largest splittable cluster into next smaller dimension.

        Clusters already at full precision are skipped, so they cannot block
        the refinement of other oversized clusters. Returns True if a cluster
        was split, False if there is nothing left to split (no splittable
        cluster holds more than `threshold` pixels, or there are no clusters).
        """
        candidates = [
            (key, cluster) for key, cluster in self.matrix.items()
            if cluster.mask != self._FULL_MASK
        ]
        if not candidates:
            return False

        # max() returns the first maximum, so ties resolve in insertion order.
        largest_key, largest = max(candidates, key=lambda kc: len(kc[1].pixels))
        if len(largest.pixels) <= threshold:
            return False

        # Remove from dict and re-add individual pixels with next smaller mask.
        del self.matrix[largest_key]
        new_mask = 0x80 | (largest.mask >> 1)
        for px in largest.pixels:
            self.__add_pixel(*px, new_mask)
        return True


    def __quantize(self, colors: int, max_cluster_size: float) -> PbrMap:
        """
        Reduce number of clusters to max 'colors' and creates a map
        that can be used to build color ramps.
        """
        # Limit refine loops
        max_refines = 64

        # Stop once no splittable cluster holds more than max_cluster_size
        # (a fraction, e.g. 0.2 = 20%) of all pixels.
        total = sum(len(c.pixels) for c in self.matrix.values())
        threshold = round(max_cluster_size * total)
        for _ in range(max_refines):
            if not self.__refine(threshold):
                # Matrix is unchanged, so further refines would be no-ops too.
                break

        # Drop empty clusters, then keep the 'colors' clusters with the most
        # pixels (stable sort: ties keep insertion order).
        result = [c for c in self.matrix.values() if c.pixels]
        result.sort(key=lambda c: len(c.pixels), reverse=True)
        result = result[:colors]

        # Means are only needed for the clusters that are kept.
        for cluster in result:
            cluster.update_mean()

        return PbrMap([c.to_map_entry() for c in result])


    def process(
        self,
        diffuse_map: List[Tuple[int, int, int]],
        metallic_map: List[int],
        specular_map: List[int],
        roughness_map: List[int],
        colors: int,
        max_cluster_size: float
        ) -> PbrMap:
        """
        Do whole processing from source pixel information and return PBR map.

        The maps are walked in lockstep (zip stops at the shortest one); only
        the first three components of each diffuse entry are used. Clusters
        from a previous call are discarded first, so each call only
        quantizes its own input. Empty input yields an empty map.
        """
        self.matrix = {}
        for color, metallic, specular, roughness in zip(diffuse_map, metallic_map, specular_map, roughness_map):
            self.__add_pixel(color[0], color[1], color[2], metallic, specular, roughness)

        return self.__quantize(colors, max_cluster_size)


