## @file calibration.py
## @brief Auto-calibration module for projector overlap detection.

import cv2
import numpy as np
import time
from typing import Tuple, Optional


## @class CalibrationManager
## @brief Handles auto-calibration of projector overlap using a camera.
##
## Detailed Description
## This class uses computer vision to detect misalignment between two projectors.
## It analyzes camera-captured images using SIFT feature detection and FLANN-based matching
## to calculate horizontal displacement between left and right projected images.
##
## Public Member Functions:
## - __init__(camera_index=0): Constructor
## - capture_image(): Captures frame from camera
## - calculate_overlap(img_left, img_right): Calculates overlap using SIFT
## - simulate_calibration(image_width, current_overlap): Simulates calibration
##
## Public Attributes:
## - camera_index: Camera device index
## - sift: SIFT feature detector instance
## - flann: FLANN-based feature matcher
class CalibrationManager:

    ## @brief Initializes the calibration manager.
    ##
    ## Creates a SIFT detector and a FLANN matcher that uses a KD-tree index
    ## (algorithm id 1, 5 trees, 50 checks per search).
    ##
    ## @param camera_index Camera device index (default: 0)
    def __init__(self, camera_index: int = 0):
        self.camera_index = camera_index  ##< @brief Camera device index
        self.sift = cv2.SIFT_create()     ##< @brief SIFT feature detector

        FLANN_INDEX_KDTREE = 1
        index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
        search_params = dict(checks=50)
        self.flann = cv2.FlannBasedMatcher(index_params, search_params)  ##< @brief FLANN matcher

    ## @brief Captures a single frame from the camera.
    ##
    ## @param warmup_frames Number of frames read and discarded before the
    ##        returned frame (default: 10)
    ## @return Captured frame as numpy array, or None if capture fails
    def capture_image(self, warmup_frames: int = 10) -> Optional[np.ndarray]:
        """
        Captures a single frame from the camera.

        The first ``warmup_frames`` frames are read and discarded before the
        frame that is returned (presumably so exposure/white balance can
        settle after opening the device - confirm for the camera in use).

        Returns None if the camera cannot be opened or the read fails.
        The capture device is always released, even if a read raises.
        """
        cap = cv2.VideoCapture(self.camera_index)
        try:
            if not cap.isOpened():
                print(f"Error: Could not open camera {self.camera_index}")
                return None

            for _ in range(warmup_frames):
                cap.read()

            ret, frame = cap.read()
        finally:
            # Release on every path (including exceptions) so the device is
            # not left locked for the next capture.
            cap.release()

        if not ret:
            print("Error: Could not read frame")
            return None

        return frame

    ## @brief Calculates horizontal overlap between two images using SIFT features.
    ##
    ## @param img_left Image from left projector only
    ## @param img_right Image from right projector only
    ## @param ratio Lowe ratio-test threshold (default: 0.7)
    ## @param ransac_thresh RANSAC reprojection threshold in pixels (default: 5.0)
    ##
    ## @return Median horizontal displacement (right minus left) in pixels,
    ##         or None if it cannot be estimated
    def calculate_overlap(
        self,
        img_left: np.ndarray,
        img_right: np.ndarray,
        ratio: float = 0.7,
        ransac_thresh: float = 5.0,
    ) -> Optional[float]:
        """
        Calculates the horizontal overlap between two images using SIFT features.

        img_left: The captured image when ONLY the left projector is on.
        img_right: The captured image when ONLY the right projector is on.
        Both may be BGR, BGRA or single-channel images.

        Returns the median x-displacement (right minus left, in pixels) over
        the RANSAC inlier matches. Returns None when an input is missing,
        there are too few features/matches, no homography is found, or no
        inliers survive RANSAC.
        """
        if img_left is None or img_right is None:
            print("Missing input image.")
            return None

        gray_left = self._to_gray(img_left)
        gray_right = self._to_gray(img_right)

        kp1, des1 = self.sift.detectAndCompute(gray_left, None)
        kp2, des2 = self.sift.detectAndCompute(gray_right, None)

        # knnMatch with k=2 needs at least two train descriptors.
        if des1 is None or des2 is None or len(kp1) < 2 or len(kp2) < 2:
            print("Not enough features detected.")
            return None

        matches = self.flann.knnMatch(des1, des2, k=2)
        good_matches = self._ratio_test(matches, ratio)

        # findHomography needs at least four correspondences.
        if len(good_matches) < 4:
            print("Not enough good matches found.")
            return None

        pts_left = np.float32([kp1[m.queryIdx].pt for m in good_matches]).reshape(-1, 1, 2)
        pts_right = np.float32([kp2[m.trainIdx].pt for m in good_matches]).reshape(-1, 1, 2)

        # The homography itself is not used further; RANSAC is run for its
        # inlier mask, which discards outlier matches before measuring dx.
        H, mask = cv2.findHomography(pts_left, pts_right, cv2.RANSAC, ransac_thresh)
        if H is None or mask is None:
            print("Could not find homography.")
            return None

        dx = self._median_inlier_dx(kp1, kp2, good_matches, mask)
        if dx is None:
            # Do not report 0 here: that would read as "perfectly aligned".
            print("No inlier matches after RANSAC.")
        return dx

    ## @brief Converts an image to single-channel grayscale for SIFT.
    @staticmethod
    def _to_gray(img: np.ndarray) -> np.ndarray:
        """Return ``img`` as a 2-D grayscale image (BGR, BGRA or 1-channel input)."""
        if img.ndim == 2:
            return img
        if img.shape[2] == 1:
            return img[:, :, 0]
        code = cv2.COLOR_BGRA2GRAY if img.shape[2] == 4 else cv2.COLOR_BGR2GRAY
        return cv2.cvtColor(img, code)

    ## @brief Applies Lowe's ratio test to k=2 nearest-neighbour matches.
    @staticmethod
    def _ratio_test(knn_matches, ratio: float) -> list:
        """
        Keep the best match of each pair whose distance is below
        ``ratio`` times the second-best distance.

        Entries with fewer than two neighbours (which FLANN can return)
        are skipped instead of raising on tuple unpacking.
        """
        return [
            pair[0]
            for pair in knn_matches
            if len(pair) >= 2 and pair[0].distance < ratio * pair[1].distance
        ]

    ## @brief Median x-displacement over the matches flagged as inliers.
    @staticmethod
    def _median_inlier_dx(kp1, kp2, good_matches, mask) -> Optional[float]:
        """
        Return the median of ``kp2[trainIdx].x - kp1[queryIdx].x`` over the
        matches whose corresponding ``mask`` entry is non-zero, or None if
        no match is flagged.
        """
        dx_list = [
            kp2[m.trainIdx].pt[0] - kp1[m.queryIdx].pt[0]
            for m, keep in zip(good_matches, np.asarray(mask).ravel())
            if keep
        ]
        if not dx_list:
            return None
        return float(np.median(dx_list))

    ## @brief Simulates the calibration process for testing.
    ##
    ## @param image_width Input image width (unused)
    ## @param current_overlap Current overlap value
    ## @param target Desired overlap value (default: 100)
    ##
    ## @return Suggested overlap adjustment in pixels
    def simulate_calibration(self, image_width: int, current_overlap: int, target: int = 100) -> int:
        """
        Simulates the calibration process.

        Returns ``target - current_overlap`` plus a random integer offset in
        [-2, 2], as a plain Python int.
        """
        error = target - current_overlap
        # randint's upper bound is exclusive, so this draws from {-2, ..., 2}.
        noise = int(np.random.randint(-2, 3))
        return error + noise