"""
@file altmain.py
@brief Projection Splitter with Overlap Blending using PyTorch.

@details
This module provides a projection image processing system designed for 
multi-projector setups or panoramic displays. It takes an input image, 
splits it into two halves (left and right), and applies an overlap blending 
region between them to ensure seamless projection alignment.

It supports multiple blending modes (linear, quadratic, and Gaussian) 
and performs GPU-accelerated computation via PyTorch for high efficiency.
"""

import cv2
import numpy as np
import math
from scipy.special import erf
import torch
import config_reader


class ProjectionSplit:
    """
    @brief Handles image splitting and blending for projection alignment.

    @details
    The ProjectionSplit class processes both static images and individual
    frames from video sources. It divides the input image into two parts
    that share an overlap strip of configurable width. It then fades the
    alpha channel of each part across that strip using a selected blending
    curve.

    The alpha multiplication runs through PyTorch. It uses the Metal (MPS)
    backend when available, otherwise CUDA, and otherwise the CPU.
    """

    def __init__(self):
        """@brief Initialize the ProjectionSplit object and configuration.

        @details
        Initializes placeholders for the left, right and main images. Also
        creates a ConfigReader for `config.ini`, which supplies the slope
        used by the "linear" blend mode.
        """
        self.image_left = None   # last blended left half produced by process_*
        self.image_right = None  # last blended right half produced by process_*
        self.image_main = None   # last input image, as received (before BGRA conversion)
        self.cfg = config_reader.ConfigReader("config.ini")

    def split_and_blend(self, image, overlap: int = 75, blend_type: str = "linear"):
        """@brief Split an image into two overlapping halves and fade their alpha across the overlap.

        @details
        This method performs no file I/O and does not touch the instance's
        image_* attributes. It is therefore suitable for real-time use;
        process_frame() and process_images() are built on top of it.

        The overlap strip is exactly `overlap` columns wide. It covers the
        same source columns in both halves:
        [width // 2 - overlap // 2, width // 2 - overlap // 2 + overlap).
        The input array is never modified.

        @param image Input image (NumPy array): grayscale (H x W or H x W x 1), BGR or BGRA.
        @param overlap Width in pixels of the shared strip, in [0, width]. 0 disables blending.
        @param blend_type Blending curve: "linear", "quadratic" or "gaussian".
        @return Tuple (left, right) of BGRA arrays with the same dtype as `image`.
        @throws FileNotFoundError If image is None.
        @throws ValueError If blend_type is unknown, overlap is out of range,
                or the channel count is unsupported.
        """
        if image is None:
            raise FileNotFoundError("Error: input image is None (not found or could not be loaded).")
        bgra = self._to_bgra(image)
        width = bgra.shape[1]
        if not 0 <= overlap <= width:
            raise ValueError(f"overlap must be in [0, {width}], got {overlap}")

        # Build the curves first so an unknown blend_type fails before any copying.
        alpha_left_curve, alpha_right_curve = self._blend_curves(overlap, blend_type)

        # Anchor the strip at right_start and derive left_end from it. This
        # keeps the strip exactly `overlap` wide when overlap is odd; the old
        # symmetric `split_x +/- overlap // 2` lost one column.
        right_start = width // 2 - overlap // 2
        left_end = right_start + overlap

        left_img = bgra[:, :left_end].copy()
        right_img = bgra[:, right_start:].copy()

        # With overlap == 0 a `-0:` slice would select the whole image, so skip.
        if overlap > 0:
            device = self._select_device()
            left_img[:, -overlap:, 3] = self._scale_alpha(
                left_img[:, -overlap:, 3], alpha_left_curve, device)
            right_img[:, :overlap, 3] = self._scale_alpha(
                right_img[:, :overlap, 3], alpha_right_curve, device)
        return left_img, right_img

    def process_frame(self, image, overlap: int = 75, blend_type: str = "linear"):
        """@brief Process a single input frame into left and right projections with overlap blending.

        @details
        Stores `image` in self.image_main and splits it with overlap
        blending (see split_and_blend()). Stores the halves in
        self.image_left / self.image_right and writes them to 'left.png'
        and 'right.png' in the current working directory.

        Available blend types:
        - **linear**: Linear transition whose slope is read from config.ini.
        - **quadratic**: Parabolic ramps, (1 - x)^2 and x^2.
        - **gaussian**: Gaussian-CDF (erf) ramp centred on the strip.

        @param image The input image (NumPy array): grayscale, BGR or BGRA.
        @param overlap Integer pixel width of the overlapping area. Default: 75.
        @param blend_type "linear" (default), "quadratic" or "gaussian".
        @throws FileNotFoundError If the image is None.
        @throws ValueError If blend_type is invalid or overlap is outside [0, width].
        """
        if image is None:
            raise FileNotFoundError("Error: input.png not found or could not be loaded.")
        self.image_main = image
        self.image_left, self.image_right = self.split_and_blend(image, overlap, blend_type)
        self._save_halves()

    def process_images(self, overlap: int = 75, blend_type: str = "linear"):
        """@brief Process a static image file and generate blended left/right outputs.

        @details
        Reads 'input.png' (unchanged bit depth and channels) from the
        current working directory and applies the same split/blend as
        process_frame(). Saves the halves as 'left.png' and 'right.png'.

        @param overlap Integer pixel width of the overlapping region. Default: 75.
        @param blend_type "linear" (default), "quadratic" or "gaussian".
        @throws FileNotFoundError If 'input.png' is not found.
        @throws ValueError If blend_type is invalid or overlap is outside [0, width].
        """
        image = cv2.imread("input.png", cv2.IMREAD_UNCHANGED)
        self.image_main = image
        if image is None:
            raise FileNotFoundError("Error: input.png not found.")
        self.image_left, self.image_right = self.split_and_blend(image, overlap, blend_type)
        self._save_halves()

    def _save_halves(self):
        """Write the stored left/right halves to 'left.png' and 'right.png'."""
        cv2.imwrite("left.png", self.image_left)
        cv2.imwrite("right.png", self.image_right)

    @staticmethod
    def _to_bgra(image):
        """Return `image` with four BGRA channels, converting grayscale or BGR via OpenCV."""
        if image.ndim == 2 or image.shape[2] == 1:
            return cv2.cvtColor(image, cv2.COLOR_GRAY2BGRA)
        if image.shape[2] == 3:
            return cv2.cvtColor(image, cv2.COLOR_BGR2BGRA)
        if image.shape[2] != 4:
            raise ValueError(f"Unsupported channel count: {image.shape[2]}")
        return image

    def _blend_curves(self, overlap, blend_type):
        """Return (alpha_left, alpha_right) ramps of length `overlap`, clipped to [0, 1]."""
        x = np.linspace(0, 1, overlap)
        if blend_type == "linear":
            # Slope from config.ini; a slope of 1 gives a plain 1-x / x cross-fade.
            slope = self.cfg.get_linear_parameter()
            alpha_left_curve = 1 - slope * x
            alpha_right_curve = 1 - slope + slope * x
        elif blend_type == "quadratic":
            alpha_left_curve = (1 - x) ** 2
            alpha_right_curve = x ** 2
        elif blend_type == "gaussian":
            # Gaussian CDF centred at the strip midpoint; sigma is relative to strip width.
            sigma = 0.25
            g = 0.5 * (1 + erf((x - 0.5) / (sigma * np.sqrt(2))))
            alpha_left_curve = 1 - g
            alpha_right_curve = g
        else:
            raise ValueError(f"Unknown blend_type '{blend_type}'")
        # A slope outside [0, 1] would push alpha outside [0, 1], which the
        # integer cast would otherwise wrap around.
        return np.clip(alpha_left_curve, 0.0, 1.0), np.clip(alpha_right_curve, 0.0, 1.0)

    @staticmethod
    def _select_device():
        """Pick a PyTorch device: Metal (MPS) if available, then CUDA, else CPU."""
        mps = getattr(torch.backends, "mps", None)
        if mps is not None and mps.is_available():
            return "mps"
        if torch.cuda.is_available():
            return "cuda"
        return "cpu"

    @staticmethod
    def _scale_alpha(alpha, curve, device):
        """Multiply each column of the 2-D `alpha` strip by `curve`; return it in alpha's dtype."""
        # float32 on the host first: MPS lacks float64 and older torch lacks uint16.
        alpha_t = torch.from_numpy(np.ascontiguousarray(alpha, dtype=np.float32)).to(device)
        curve_t = torch.from_numpy(np.ascontiguousarray(curve, dtype=np.float32)).to(device)
        scaled = (alpha_t * curve_t.unsqueeze(0)).cpu().numpy()
        if np.issubdtype(alpha.dtype, np.integer):
            # Round rather than truncate, and clamp to the integer type's range.
            info = np.iinfo(alpha.dtype)
            scaled = np.clip(np.rint(scaled), info.min, info.max)
        return scaled.astype(alpha.dtype)
