import cv2
import numpy as np
from ultralytics import YOLO  # Import YOLO from ultralytics package
import supervision as sv
from django.db import transaction
from myprofile.models import Camera, CameraEvent, Roi, Alert, Organization, Location, Area
from django.core.management.base import BaseCommand
from django.core.files.base import ContentFile
from multiprocessing import Process
from django.conf import settings
import requests

def _point_in_any_roi(point, roi_coords_list):
    """Return True if *point* ``(x, y)`` lies strictly inside any ROI polygon.

    ``cv2.pointPolygonTest`` returns a positive value inside the contour,
    0 on the edge and a negative value outside.
    """
    return any(
        cv2.pointPolygonTest(roi_coords, point, False) > 0
        for roi_coords in roi_coords_list
    )


def _create_danger_alert(camera, camera_id, annotated_frame):
    """Persist an Alert with a JPEG snapshot and push the image to the upload API.

    Encodes *annotated_frame* as JPEG, stores it on a new ``Alert`` linked to
    the camera's organization/location/area, attaches every camera-in-charge
    user, then POSTs the image to ``settings.URL + '/alert_image_upload'``.
    The upload is best-effort: the alert row is already saved before it runs.
    """
    camera_event_obj = CameraEvent.objects.filter(camera_event='Danger Zone').first()
    _, encoded_frame = cv2.imencode('.jpg', annotated_frame, [int(cv2.IMWRITE_JPEG_QUALITY), 80])
    # ContentFile needs real bytes: a numpy buffer fails Django's bytes/str
    # detection and ends up in StringIO, raising TypeError.
    jpeg_bytes = encoded_frame.tobytes()
    image_file = ContentFile(jpeg_bytes, name="alert_image.jpg")

    org_obj = Organization.objects.all().first()
    # Guard camera itself too: the caller falls back to a test video when the
    # camera row is missing, in which case camera is None.
    loc_obj = (
        Location.objects.filter(loc_name=camera.location.loc_name).first()
        if camera and camera.location else None
    )
    area_obj = (
        Area.objects.filter(area_name=camera.area.area_name).first()
        if camera and camera.area else None
    )

    alert = Alert(
        organization=org_obj,
        camera=camera,
        detection_choice='',
        alert_message="Person detected in danger zone",
        frame=image_file,
        camera_events=camera_event_obj,
    )
    if area_obj:
        alert.area = area_obj
    if loc_obj:
        alert.location = loc_obj
    alert.save()

    # M2M additions hit the database immediately; no extra save() required.
    for user in Camera.objects.get(id=camera_id).camera_user.all():
        alert.camera_incharge.add(user)

    url = settings.URL + '/alert_image_upload'
    files = {'image': (alert.frame.name.replace('alert_images/', ''), jpeg_bytes)}
    try:
        response = requests.post(url, files=files)
        print(response.json())
    except (requests.RequestException, ValueError) as exc:
        # Upload/JSON failures must not kill the detection loop; the alert
        # is already persisted locally.
        print(f"Alert image upload failed: {exc}")


def danger_detection(camera_id):
    """Watch one camera stream and raise alerts when a person enters a danger zone.

    Opens the camera's RTSP stream (or a local test video when the camera row
    is missing), runs YOLOv8 person detection with ByteTrack tracking on every
    16th frame, and creates an ``Alert`` (plus remote image upload) for each
    tracked person whose bounding-box center falls inside any 'Danger Zone'
    ROI configured for the camera. Runs until the stream ends.
    """
    # YOLOv8 nano model; COCO class 0 is 'person'.
    model = YOLO("yolov8n.pt")
    tracker = sv.ByteTrack()

    camera = Camera.objects.filter(id=camera_id).first()
    if camera:
        video_path = camera.rtsp_url
    else:
        video_path = "/home/nettyfy/visnx/visnx-backend/Nettyfy_visnx/videos/danger_test.mp4"
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("Error: Could not open video.")
        return

    # Collect every 'Danger Zone' ROI polygon for this camera.
    roi_objs = Roi.objects.filter(
        camera=camera, camera_events__camera_event='Danger Zone'
    ).order_by('id')
    roi_coords_list = [
        np.array(roi_obj.coordinates['coordinates'], dtype=np.int32)
        for roi_obj in roi_objs
    ]
    if not roi_coords_list:
        # Without ROIs the original code crashed with IndexError on the first
        # detection; bail out cleanly instead.
        print(f"No 'Danger Zone' ROI configured for camera {camera_id}; nothing to monitor.")
        cap.release()
        return

    skip_frames = 15
    frame_skip_counter = 0

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        # Process only every (skip_frames + 1)-th frame to bound CPU load.
        frame_skip_counter += 1
        if frame_skip_counter <= skip_frames:
            continue
        frame_skip_counter = 0

        # ROIs are stored in this resized coordinate space.
        frame = cv2.resize(frame, (1000, 700))

        annotated_frame = frame.copy()
        for roi_coords in roi_coords_list:
            cv2.drawContours(annotated_frame, [roi_coords], -1, (255, 255, 0), 2)

        results = model(frame)
        if isinstance(results, list) and len(results) > 0:
            result = results[0]
            detections = sv.Detections(
                xyxy=np.array(result.boxes.xyxy),
                confidence=np.array(result.boxes.conf),
                class_id=np.array(result.boxes.cls).astype(int),
            )
            # Only persons should trigger danger-zone alerts (COCO class 0);
            # the original alerted on every detected object class.
            detections = detections[detections.class_id == 0]
        else:
            # sv.Detections() with no arguments raises (xyxy is required);
            # use the canonical empty constructor.
            detections = sv.Detections.empty()

        detections = tracker.update_with_detections(detections)

        if isinstance(detections.tracker_id, np.ndarray):
            for box, track_id in zip(detections.xyxy, detections.tracker_id):
                x1, y1, x2, y2 = map(int, box)
                center = ((x1 + x2) / 2, (y1 + y2) / 2)

                # Test against ALL danger zones (the original only ever
                # checked roi_coords_list[0]).
                if _point_in_any_roi(center, roi_coords_list):
                    print(f"Person ID {track_id} entered in danger zone")
                    cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                    _create_danger_alert(camera, camera_id, annotated_frame)
                else:
                    print("Outside:", track_id)

    cap.release()
    # No output video is ever written, so don't claim one was saved.
    print(f"Finished danger-zone detection for camera {camera_id}")

class Command(BaseCommand):
    """Run danger-zone detection for every enabled camera, one process each."""

    help = 'Queue detection based on video analysis'

    def handle(self, *args, **options):
        """Spawn one worker process per enabled camera and wait for all of them.

        YOLO inference is CPU-heavy, so each camera stream gets its own
        process (not a thread) for true parallelism across cores.
        """
        cameras = Camera.objects.filter(is_enabled=True)

        # Keep track of the processes so we can wait on them later.
        processes = []
        for device in cameras:
            process = Process(target=danger_detection, args=(device.id,))
            processes.append(process)
            process.start()

        self.stdout.write(f"Started {len(processes)} detection process(es)")
        for process in processes:
            process.join()
