#!/usr/bin/env python
# -*- coding: utf-8 -*-

import cv2
import numpy as np
import os
from config_reader import ConfigReader


class MainAlphaBlender(object):
    """Fade the facing edges of a Left/Right image pair and gamma-correct them.

    Settings are read through ``ConfigReader(config_path)``; if that raises
    ``FileNotFoundError`` every setting falls back to a built-in default.
    The inputs are ``<image_path>/Left.jpg`` and ``<image_path>/Right.jpg``.
    """

    def __init__(self, config_path="config.json"):
        """Load settings from ``config_path`` (or defaults) and derive the input paths."""
        try:
            self.__config_reader = ConfigReader(config_path)
            self.blend_width = self.__config_reader.get_blend_width()
            self.gamma_value = self.__config_reader.get_gamma_value()
            self.method = self.__config_reader.get_blend_method()
            self.output_dir = self.__config_reader.get_output_dir()
            self.preview = self.__config_reader.get_preview()
            self.image_path = self.__config_reader.get_image_path()
        except FileNotFoundError:
            # No config file: use defaults for every setting.
            self.blend_width = 200
            self.gamma_value = 1.4
            self.method = "cosine"
            self.output_dir = "Results"
            self.preview = True
            self.image_path = "OriginalImages"
        self.update_paths()

    def update_paths(self):
        """Derive the Left/Right input image paths from ``self.image_path``."""
        self.left_image_path = os.path.join(self.image_path, "Left.jpg")
        self.right_image_path = os.path.join(self.image_path, "Right.jpg")

    @staticmethod
    def _rescale_to_unit(values):
        """Min-max rescale a 1-D array to [0, 1]; empty input is returned as is, constant input becomes zeros."""
        if values.size == 0:
            return values
        low = values.min()
        span = values.max() - low
        if span == 0:
            # A single sample (or a flat curve) has no range to stretch;
            # zeros match what every other method yields for blend_width == 1.
            return np.zeros_like(values)
        return (values - low) / span

    def create_alpha_gradient(self, blend_width, side, method="cosine"):
        """Return a 1-D ramp of ``blend_width`` alpha values spanning [0, 1].

        The ramp rises from 0 to 1. For ``side == 'right'`` it is mirrored
        (1 down to 0); any other ``side`` value leaves it rising.

        Args:
            blend_width: number of samples in the ramp.
            side: 'right' to mirror the ramp; otherwise it is returned rising.
            method: one of 'linear', 'cosine', 'quadratic', 'sqrt', 'log',
                'sigmoid' - the shape of the ramp.

        Returns:
            float ndarray of shape ``(blend_width,)``.

        Raises:
            ValueError: unknown ``method``; ``np.linspace`` also raises
                ``ValueError`` for a negative ``blend_width``.
        """
        t = np.linspace(0, 1, blend_width)
        if method == 'linear':
            alpha_gradient = t
        elif method == 'cosine':
            # Raised cosine with a warped phase. Warping pushes the phase's end
            # to pi**0.85 (~2.65 rad), so the raw curve tops out near 0.94;
            # rescale so it ends at exactly 1 like every other method.
            phase = np.linspace(0, np.pi, blend_width) ** 0.85
            alpha_gradient = self._rescale_to_unit((1 - np.cos(phase)) / 2)
        elif method == 'quadratic':
            alpha_gradient = t ** 2
        elif method == 'sqrt':
            alpha_gradient = np.sqrt(t)
        elif method == 'log':
            alpha_gradient = np.log1p(9 * t) / np.log1p(9)
        elif method == 'sigmoid':
            # The logistic never reaches 0 or 1 exactly; stretch it to [0, 1].
            alpha_gradient = self._rescale_to_unit(1 / (1 + np.exp(-12 * (t - 0.5))))
        else:
            raise ValueError("Invalid method: choose from 'linear', 'cosine', 'quadratic', 'sqrt', 'log', or 'sigmoid'")
        if side == 'right':
            alpha_gradient = 1 - alpha_gradient
        return alpha_gradient

    def gamma_correction(self, image, gamma):
        """Apply a gamma adapted to the image's mean intensity.

        Pixel values are scaled by 1/255 (i.e. 8-bit input is assumed). The
        effective gamma is ``gamma * 0.5 / mean``, clipped to [0.8, 2.0].
        Each pixel is raised to ``1 / effective_gamma``. The result is
        returned as ``uint8``.
        """
        img_float = image.astype(np.float32) / 255.0
        mean_intensity = np.mean(img_float)
        # The epsilon avoids division by zero for an all-black image.
        adaptive_gamma = gamma * (0.5 / (mean_intensity + 1e-5))
        adaptive_gamma = np.clip(adaptive_gamma, 0.8, 2.0)
        corrected = np.power(img_float, 1.0 / adaptive_gamma)
        return np.uint8(np.clip(corrected * 255, 0, 255))

    def alpha_blend_edge(self, image, blend_width, side, method="cosine"):
        """Return a copy of ``image`` whose ``side`` edge fades out over ``blend_width`` columns.

        The outermost column of the chosen edge is multiplied by 0 and the
        innermost blended column by the ramp's 1 end.

        Args:
            image: array of shape (H, W) or (H, W, C). The ramp is broadcast
                over rows and channels.
            blend_width: number of edge columns to fade; must be in [0, W].
            side: 'left' or 'right'.
            method: ramp shape, as accepted by ``create_alpha_gradient``.

        Returns:
            A new array with the same shape and dtype as ``image``. The input
            is not modified.

        Raises:
            ValueError: invalid ``side``, ``method`` or ``blend_width``.
        """
        if side not in ('left', 'right'):
            raise ValueError("Side must be 'left' or 'right'")
        width = image.shape[1]
        if not 0 <= blend_width <= width:
            # Without this check an oversized width yields a negative slice start
            # and a confusing broadcasting error further down.
            raise ValueError(
                f"blend_width must be between 0 and the image width ({width}), got {blend_width}"
            )
        alpha_gradient = self.create_alpha_gradient(blend_width, side, method)
        # Shape (1, blend_width[, 1]) broadcasts over rows and any channel count
        # without materialising a full-size copy of the ramp.
        gradient = alpha_gradient.reshape((1, alpha_gradient.size) + (1,) * (image.ndim - 2))
        if side == 'right':
            columns = slice(width - blend_width, width)
        else:
            columns = slice(0, blend_width)
        blended_image = image.copy()
        # Truncating cast back to the input dtype (uint8 for images from cv2.imread).
        blended_image[:, columns] = (blended_image[:, columns] * gradient).astype(image.dtype)
        return blended_image

    def show_preview(self, left_image, right_image, scale=0.5):
        """Show both images side by side in one window until a key is pressed.

        Both images are resized to a common height (``scale`` times the
        shorter input height) so they can be stacked. Each keeps its own
        aspect ratio. The window is always closed afterwards, even if
        display fails.
        """
        target_height = max(1, int(min(left_image.shape[0], right_image.shape[0]) * scale))

        def fit_height(img):
            # Scale the width by the same factor as the height to avoid distortion.
            target_width = max(1, int(img.shape[1] * target_height / img.shape[0]))
            return cv2.resize(img, (target_width, target_height))

        combined = np.hstack((fit_height(left_image), fit_height(right_image)))
        try:
            cv2.imshow("Preview (Left + Right)", combined)
            cv2.waitKey(0)
        finally:
            cv2.destroyAllWindows()

    def run(self):
        """Blend, gamma-correct and save the Left/Right pair, optionally previewing it.

        The left image fades out along its right edge and the right image
        along its left edge. Both are then gamma-corrected and written to
        ``output_dir`` as ``<method>_left_gamma.jpg`` and
        ``<method>_right_gamma.jpg``. If ``preview`` is truthy they are also
        shown in a window.

        Returns:
            ``(True, message)`` on success or ``(False, error_message)`` on
            failure. Errors are reported through the tuple, not raised.
        """
        try:
            os.makedirs(self.output_dir, exist_ok=True)
            left_img = cv2.imread(self.left_image_path, cv2.IMREAD_COLOR)
            right_img = cv2.imread(self.right_image_path, cv2.IMREAD_COLOR)
            if left_img is None or right_img is None:
                raise FileNotFoundError(f"Could not read images from '{self.image_path}'. Check path.")
            left_blended = self.alpha_blend_edge(left_img, self.blend_width, side='right', method=self.method)
            right_blended = self.alpha_blend_edge(right_img, self.blend_width, side='left', method=self.method)
            left_gamma = self.gamma_correction(left_blended, self.gamma_value)
            right_gamma = self.gamma_correction(right_blended, self.gamma_value)
            outputs = (
                (os.path.join(self.output_dir, f"{self.method}_left_gamma.jpg"), left_gamma),
                (os.path.join(self.output_dir, f"{self.method}_right_gamma.jpg"), right_gamma),
            )
            for output_path, output_image in outputs:
                # cv2.imwrite signals failure by returning False rather than raising.
                if not cv2.imwrite(output_path, output_image):
                    return (False, f"Could not write image to '{output_path}'.")
            if self.preview:
                # show_preview closes its own window; no GUI call is made when
                # preview is off, so headless OpenCV builds work.
                self.show_preview(left_gamma, right_gamma)
            return (True, f"Images saved successfully in '{self.output_dir}'.")
        except (FileNotFoundError, ValueError) as e:
            return (False, str(e))
        except Exception as e:
            # Top-level boundary: report anything else (e.g. cv2.error) as a failure.
            return (False, f"An unexpected error occurred: {e}")

