#!/bin/python3
'''
    SPDX-FileCopyrightText: 2025 Agata Cacko <cacko.azh@gmail.com>

    This file is part of Fast Sketch Cleanup Plugin for Krita

    SPDX-License-Identifier: GPL-3.0-or-later
'''


import converter as conv
import numpy as np

from PIL import Image
import openvino.runtime.opset13 as ops
import openvino as ov
import sys


class PreProcessingInfo:
	"""Settings consumed by applyPreProcessingNumpy: levels, resize factor and inversion."""

	# Bounds handed to applyLevels; the default (0.0, 1.0) leaves values unchanged.
	preprocessingLevelsLowerValue: float = 0.0
	preprocessingLevelsUpperValue: float = 1.0
	# Multiplier applied to the image width and height; 1.0 skips resizing.
	resizeInPreprocessingValue: float = 1.0
	# When true, the image is inverted after leveling and resizing.
	invert: bool = False

	def isEqualTo(self, other) -> bool:
		"""Return True if `other` carries the same four settings; `None` is never equal."""
		if other is None:
			return False
		return (self.preprocessingLevelsLowerValue == other.preprocessingLevelsLowerValue
			and self.preprocessingLevelsUpperValue == other.preprocessingLevelsUpperValue
			and self.resizeInPreprocessingValue == other.resizeInPreprocessingValue
			and self.invert == other.invert)

	def __str__(self):
		# List every setting (the resize factor and invert flag were previously
		# replaced by a literal "resizeIn..."), mirroring PostProcessingInfo's format.
		return (f"Preprocessing: levels = ({self.preprocessingLevelsLowerValue}, {self.preprocessingLevelsUpperValue}), "
			f"resize in pre: {self.resizeInPreprocessingValue}, invert = {self.invert}")


class PostProcessingInfo:
	"""Settings consumed by applyPostProcessingNumpy (inversion, resize undo, levels, sharpening)."""

	# The scale used during preprocessing; post-processing resizes by its reciprocal.
	resizeInPreprocessingValue: float = 1.0
	# When true, the image is inverted before any other post-processing step.
	invert: bool = False
	# Bounds handed to applyLevels; the default (0.0, 1.0) leaves values unchanged.
	postprocessingLevelsLowerValue: float = 0.0
	postprocessingLevelsUpperValue: float = 1.0
	# Whether applySharpen runs, and the strength factor passed to it.
	isSharpenChecked: bool = False
	sharpenStrength: float = 1.0

	def isEqualTo(self, other) -> bool:
		"""Return True if every setting matches `other`; `None` is never equal."""
		if other is None:
			return False
		# Compared in a fixed order, stopping at the first mismatch.
		fieldNames = (
			"invert",
			"resizeInPreprocessingValue",
			"postprocessingLevelsLowerValue",
			"postprocessingLevelsUpperValue",
			"isSharpenChecked",
			"sharpenStrength",
		)
		return all(getattr(self, name) == getattr(other, name) for name in fieldNames)

	def __str__(self):
		levels = f"({self.postprocessingLevelsLowerValue}, {self.postprocessingLevelsUpperValue})"
		return (f"Postprocessing: levels = {levels}, "
			f"resize in pre: {self.resizeInPreprocessingValue}, invert = {self.invert}, "
			f"isSharpenChecked = {self.isSharpenChecked}, sharpenStrength = {self.sharpenStrength}")





def applyLevels(numpyInput: np.ndarray, lowerValue: float, upperValue: float) -> np.ndarray:
	"""Linearly remap values so lowerValue -> 0.0 and upperValue -> 1.0, clipped to [0, 1].

	The default range (0.0, 1.0) is a no-op that returns `numpyInput` itself
	(not a copy). When the bounds are closer than 1e-6 the stretch would divide
	by (almost) zero, so a hard threshold at `lowerValue` is used instead:
	values below it become 0.0, values above it 1.0, values equal to it 0.5.
	"""
	isIdentityRange = lowerValue == 0.0 and upperValue == 1.0
	if isIdentityRange:
		return numpyInput

	span = upperValue - lowerValue
	if abs(span) < 0.000001:
		# sign() yields -1/0/1 in the input's dtype; (s + 1) / 2 maps it to 0.0/0.5/1.0.
		signs = np.sign(numpyInput - lowerValue).astype(numpyInput.dtype)
		return (signs + 1.0)/2.0

	stretched = (numpyInput - lowerValue)/span
	return np.clip(stretched, 0.0, 1.0)

def applySharpen(numpyInput: np.ndarray, factor) -> np.ndarray:
	"""Sharpen an image with a 3x3 convolution executed through OpenVINO.

	Best-effort: if building, compiling or running the OpenVINO model raises,
	the error is printed to stderr and the input is returned unchanged.

	Args:
		numpyInput: image data fed directly to an OpenVINO convolution whose
			kernel has shape (1, 1, 3, 3); presumably a (N, 1, H, W) float32
			array with values in [0, 1] — TODO confirm against callers.
		factor: sharpen strength. Below 1.0 the kernel is blended with the
			identity kernel; at 1.0 or above the kernel is scaled by `factor`.

	Returns:
		The convolution result clipped to [0, 1], or `numpyInput` on failure.
	"""
	try:
		# NOTE(review): this kernel sums to 2 (10 - 4*2), not 1, so flat regions
		# come out doubled (times `factor` when factor >= 1) before clipping —
		# confirm this brightening is intended.
		sharpenKernel = np.array([[0.0, -2.0, 0.0], [-2.0, 10.0, -2.0], [0.0, -2.0, 0.0]])
		identityKernel = np.array([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]])
		if factor < 1.0:
			# Weak strengths interpolate towards the identity (no-op) kernel.
			sharpenKernel = factor*sharpenKernel + (1 - factor)*identityKernel
		else:
			sharpenKernel = factor*sharpenKernel
		# Reshape to a single-output, single-input-channel 3x3 filter.
		sharpenKernel = sharpenKernel.astype(numpyInput.dtype).reshape((1, 1, 3, 3))

		# NOTE(review): both inputs are declared f32 while the kernel takes the
		# input's dtype; inputs are presumably float32 — confirm.
		# Named kernelParam (not `filter`) to avoid shadowing the builtin.
		dataParam = ops.parameter(numpyInput.shape, ov.Type.f32)
		kernelParam = ops.parameter(sharpenKernel.shape, ov.Type.f32)
		# Arguments after the filter: strides, pads_begin, pads_end, dilations.
		convolution = ops.convolution(dataParam, kernelParam, [1, 1], [1, 1], [1, 1], [1, 1])

		result = ops.result(convolution)
		sharpenModel = ov.Model([result], [dataParam, kernelParam], "Sharpen Filter")

		compiledSharpenModel = ov.compile_model(sharpenModel)
		outputs = compiledSharpenModel([numpyInput, sharpenKernel])
		return np.clip(outputs[0], 0.0, 1.0)

	except Exception as e:
		# Deliberate best-effort: fall back to the unsharpened image.
		print(f"Exception during creating sharpen: {e}", file=sys.stderr)

	return numpyInput

def applyPreProcessingNumpy(numpyInput: np.ndarray, info: PreProcessingInfo) -> np.ndarray:
	"""Apply levels, then an optional bicubic resize, then an optional inversion.

	The resize multiplies width and height by `info.resizeInPreprocessingValue`
	(truncated to int) and goes through a Pillow round-trip via `conv`.
	"""
	leveled = applyLevels(numpyInput, info.preprocessingLevelsLowerValue, info.preprocessingLevelsUpperValue)

	factor = info.resizeInPreprocessingValue
	if factor != 1.0:
		image = conv.convertNumpyToPillow(leveled)
		targetSize = (int(image.width*factor), int(image.height*factor))
		leveled = conv.convertPillowToNumpy(image.resize(size=targetSize, resample=Image.BICUBIC))

	return conv.invert(leveled) if info.invert else leveled


def applyPostProcessingNumpy(numpyInput: np.ndarray, info: PostProcessingInfo) -> np.ndarray:
	"""Post-process an image: optional inversion, undo of the preprocessing resize,
	levels, then optional sharpening — in that order.
	"""
	result = conv.invert(numpyInput) if info.invert else numpyInput

	# Only the preprocessing scale matters here: resize by its reciprocal.
	# NOTE(review): int() truncation here and in applyPreProcessingNumpy means the
	# original size may not be restored exactly (e.g. 101 px * 0.5 -> 50 -> 100).
	preScale = info.resizeInPreprocessingValue
	if preScale != 1.0:
		undoScale = 1/preScale
		image = conv.convertNumpyToPillow(result)
		targetSize = (int(image.width*undoScale), int(image.height*undoScale))
		result = conv.convertPillowToNumpy(image.resize(size=targetSize, resample=Image.BICUBIC))

	result = applyLevels(result, info.postprocessingLevelsLowerValue, info.postprocessingLevelsUpperValue)

	if not info.isSharpenChecked:
		return result
	return applySharpen(result, info.sharpenStrength)



def uninvertIfNeeded(invert, numpyArray: np.ndarray) -> np.ndarray:
	"""Return `conv.invert(numpyArray)` when `invert` is truthy, else `numpyArray` itself."""
	return conv.invert(numpyArray) if invert else numpyArray

