import cv2
import numpy as np
from ultralytics import YOLO  # Import YOLO from ultralytics package
import supervision as sv
from django.db import transaction
from myprofile.models import Camera, Roi, Organization, Location, Area, CameraEvent, Alert
from django.core.management.base import BaseCommand
from django.core.files.base import ContentFile
from multiprocessing import Process
from django.conf import settings
import requests
import time
from django.db import close_old_connections

def danger_detection(camera_id):
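    """
    Monitor a single camera stream and raise an Alert whenever a tracked
    person enters one of the camera's danger-zone ROIs.

    Intended to run in its own process (see Command.handle below), so stale
    database connections are closed on entry and exit.
    """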
    try:
        close_old_connections()
        # Load the pretrained YOLOv8-nano model from ultralytics
        model = YOLO("yolov8n.pt")

        # Initialize tracker
        tracker = sv.ByteTrack()

        # Open the camera's RTSP stream, or a local test video if the camera record is missing
        camera = Camera.objects.filter(id=camera_id).first()
        if camera:
            video_path = camera.rtsp_url
        else:
            video_path = "/home/nettyfy/visnx/visnx-backend/Nettyfy_visnx/videos/danger_test.mp4"
        cap = cv2.VideoCapture(video_path)

        # Check if the video capture object is opened successfully
        if not cap.isOpened():
            print("Error: Could not open video.")
            return

        # Load this camera's danger-zone ROIs; each Roi stores its polygon as
        # a list of [x, y] points under the 'coordinates' key.
        roi_objs = Roi.objects.filter(camera=camera)
        roi_coords_list = [np.array(roi_obj.coordinates['coordinates'], dtype=np.int32) for roi_obj in roi_objs]
        if not roi_coords_list:
            print(f"No ROIs configured for camera {camera_id}; nothing to monitor.")
            cap.release()
            return

        # Track IDs that have already raised an alert, so each tracked person
        # triggers at most one alert.
        people_enter_queue = {}

        # Process 1 out of every (skip_frames + 1) frames to reduce load
        skip_frames = 15

        frame_skip_counter = 0

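        # Main loop: read frames, run detection, and alert on ROI entry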
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            frame_skip_counter += 1
            if frame_skip_counter <= skip_frames:
                continue
            frame_skip_counter = 0

            # Resize the frame to a fixed 1200x700 working resolution
            frame = cv2.resize(frame, (1200, 700))

            # Draw the danger-zone ROI polygons on a copy of the frame
            annotated_frame = frame.copy()
            for roi_coords in roi_coords_list:
                cv2.drawContours(annotated_frame, [roi_coords], -1, (255, 255, 0), 2)

            # Run YOLOv8 detection on the frame (tracking is handled by ByteTrack below)
            results = model(frame)

            if isinstance(results, list) and len(results) > 0:
                result = results[0]
                # Move tensors to CPU and convert to numpy arrays
                boxes = result.boxes.xyxy.cpu().numpy()
                scores = result.boxes.conf.cpu().numpy()
                class_ids = result.boxes.cls.cpu().numpy().astype(int)
                detections = sv.Detections(xyxy=boxes, confidence=scores, class_id=class_ids)
                # Keep only 'person' detections (COCO class 0)
                detections = detections[detections.class_id == 0]
            else:
                detections = sv.Detections.empty()

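            # ByteTrack assigns persistent track IDs to detections across frames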
            detections = tracker.update_with_detections(detections)
            boxes = detections.xyxy

            if isinstance(detections.tracker_id, np.ndarray):
                track_ids = detections.tracker_id

                for box, track_id in zip(boxes, track_ids):
                    x1, y1, x2, y2 = map(int, box)
                    # Use the center of the bounding box as the person's position
                    x = (x1 + x2) / 2
                    y = (y1 + y2) / 2

                    # Alert when the center falls inside any danger-zone ROI
                    if any(cv2.pointPolygonTest(roi, (x, y), False) > 0 for roi in roi_coords_list):
                        # Raise at most one alert per tracked person
                        if track_id in people_enter_queue:
                            continue
                        people_enter_queue[track_id] = True
                        cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                        org_obj = Organization.objects.all().first()
                        loc_obj = camera.location if camera else None
                        area_obj = camera.area if camera else None
                        camera_event_obj = CameraEvent.objects.filter(camera_event='Danger Zone').first()
                        # Encode the annotated frame as JPEG; ContentFile expects raw bytes
                        _, encoded_frame = cv2.imencode('.jpg', annotated_frame, [int(cv2.IMWRITE_JPEG_QUALITY), 80])
                        frame_bytes = encoded_frame.tobytes()
                        image_file = ContentFile(frame_bytes, name="alert_image.jpg")
                        alert = Alert(
                            organization=org_obj,
                            camera=camera,
                            detection_choice='',
                            alert_message="Person detected in danger zone",
                            frame=image_file,
                            camera_events=camera_event_obj,
                        )
                        if area_obj:
                            alert.area = area_obj
                        if loc_obj:
                            alert.location = loc_obj
                        alert.save()

                        if camera:
                            # Attach the camera's in-charge users; M2M .add()
                            # persists immediately, so no extra save() is needed
                            for user in camera.camera_user.all():
                                alert.camera_incharge.add(user)

                        # Forward the alert image to the upload endpoint
                        url = settings.URL + '/alert_image_upload'
                        files = {'image': (alert.frame.name.replace('alert_images/', ''), frame_bytes, 'image/jpeg')}

                        try:
                            response = requests.post(url, files=files, timeout=10)
                            print(response.json())
                        except requests.RequestException as e:
                            print(f"Alert image upload failed: {e}")

        cap.release()
    finally:
        # Explicitly close the database connection for the current process
        close_old_connections()

class Command(BaseCommand):
    help = 'Danger-zone detection based on video analysis'

    def handle(self, *args, **options):
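        """
        Spawn one danger_detection worker per enabled camera, let the workers
        run for five minutes, then terminate and restart them so newly
        enabled or disabled cameras are picked up.
        """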
        while True:  # Loop indefinitely
            close_old_connections()  # Close old database connections before starting new processes
            cameras = Camera.objects.filter(is_enabled=True).order_by('created_at')

            processes = []  # Keep track of the processes so we can terminate them later.

            try:
                for device in cameras:
                    print(f"Starting detection process for camera {device.id}")
                    # Start a worker process per camera, keeping the camera id
                    # alongside the process so it can be reported at shutdown.
                    process = Process(target=danger_detection, args=(device.id,))
                    processes.append((process, device.id))
                    process.start()

                print(f"Started {len(processes)} processes.")

                # Let the processes run for 5 minutes before recycling them
                time.sleep(300)

            except Exception as e:
                print(f"Error occurred: {str(e)}")

            finally:
                # Terminate all worker processes after the run window
                for process, cam_id in processes:
                    if process.is_alive():
                        process.terminate()
                        process.join()  # Ensure the process has finished
                        print(f"Terminated process for camera {cam_id}")

                print("Restarting processes...")

                # Wait briefly before restarting to ensure all processes are fully terminated
                time.sleep(2)