import cv2
from myprofile.tracker import *
from myprofile.models import *
from django.core.files.base import ContentFile
import django
from django.core.management.base import BaseCommand
from ultralytics import YOLO
from django.conf import settings
from multiprocessing import Process
import requests
from django.conf import settings
import io

# Project root directory (Django's BASE_DIR); process_device builds the YOLO
# weights path from it as <base_dir>/model/best1.pt.
base_dir = settings.BASE_DIR

def check_and_trigger_alerts(person_id, labels):
    """Print a console alert for a tracked person whose labels report a "no".

    Args:
        person_id: Tracker id of the person; it is only used in the printed text.
        labels: Iterable of label strings. Any label containing the substring
            "no" (for example "Helmet: no") is treated as missing equipment.

    Prints nothing when no label matches. Always returns None.
    """
    missing = [text for text in labels if "no" in text]
    if not missing:
        return
    print(person_id, "-----------------person_id-----------------")
    joined = ' '.join(missing)
    print(f"ALERT for ID {person_id}: {joined}")

# Number of frames read and discarded between two frames that are sent to the model.
_SKIP_FRAMES = 15

# Seconds to wait for the alert-image upload endpoint before giving up on it.
_UPLOAD_TIMEOUT_SECONDS = 10

# Every label the PPE model can emit, indexed by class id.
# NOTE(review): this order must match the class order of model/best1.pt — confirm.
_RULES_CLASSES = [
    'Helmet: yes', 'Jacket: yes', 'Gloves: yes', 'Shoes: yes', 'Person',
    'Helmet: no', 'Jacket: no', 'Gloves: no', 'Shoes: no',
]


def _enabled_classes(camera_event_names):
    """Return the class-id -> label table for one camera.

    Labels for equipment the camera does not monitor are blanked to ``' '``
    so they are ignored downstream; ``'Person'`` is always kept.
    """
    wanted = {f"{name}: {state}" for name in camera_event_names for state in ("yes", "no")}
    wanted.add('Person')
    return [label if label in wanted else ' ' for label in _RULES_CLASSES]


def _parse_detections(results, classes):
    """Return ``(x1, y1, x2, y2, label)`` for every box in one YOLO result.

    Class ids outside ``classes`` map to ``' '`` (ignored) instead of raising.
    """
    detections = []
    for x1, y1, x2, y2, _conf, class_idx in results.boxes.data.tolist():
        class_idx = int(class_idx)
        label = classes[class_idx] if 0 <= class_idx < len(classes) else ' '
        detections.append((x1, y1, x2, y2, label))
    return detections


def _assign_ppe_to_persons(detections, persons):
    """Add each non-person label to the first person box that contains the label box's centre."""
    for x1, y1, x2, y2, label in detections:
        if label == "Person":
            continue
        cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
        for person in persons:
            px1, py1, px2, py2 = person["bbox"]
            if px1 <= cx <= px2 and py1 <= cy <= py2:
                # "labels" is a set, so a repeated label is stored only once.
                person["labels"].add(label)
                break


def _draw_person(frame, person):
    """Draw a person's box (green) and its yes/no labels (green/red on black) onto ``frame``."""
    x1, y1, x2, y2 = person["bbox"]
    cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 3)

    no_labels = {label for label in person["labels"] if "no" in label}
    yes_labels = {label for label in person["labels"] if "yes" in label}

    # Labels are stacked upwards, 30 px apart, starting just above the box.
    text_x, text_y = int(x1), int(y1) - 10
    for label in no_labels | yes_labels:
        (text_width, text_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 1.2, thickness=1)
        cv2.rectangle(
            frame,
            (text_x, text_y),
            (text_x + text_width + 2, text_y - text_height - 2),
            (0, 0, 0),
            cv2.FILLED,
        )
        text_color = (0, 255, 0) if label in yes_labels else (0, 0, 255)
        cv2.putText(frame, label, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 1, text_color, 2)
        text_y -= 30


def _labels_with_missing_events(labels, camera_event_names):
    """Return the non-blank ``labels`` plus ``'<event>: no_detection'`` for every
    monitored event that no label mentions."""
    labels_list = [label for label in labels if label.strip()]
    for name in camera_event_names:
        if not any(name in item for item in labels_list):
            labels_list.append(f"{name}: no_detection")
    return labels_list


def _detection_choice(label):
    """Map a ``'<event>: <state>'`` label to the value stored in Alert.detection_choice."""
    state = label.partition(':')[2]
    if state == ' yes':
        return 'Yes'
    if state == ' no_detection':
        return 'NotDetected'
    return 'No'


def _upload_alert_image(alert, image_bytes):
    """POST an alert's JPEG to ``settings.URL + '/alert_image_upload'``.

    Upload failures are printed rather than raised, so one unreachable or
    misbehaving endpoint does not stop the camera's processing loop.
    """
    url = settings.URL + '/alert_image_upload'
    files = {'image': (alert.frame.name.replace('alert_images/', ''), image_bytes)}
    try:
        response = requests.post(url, files=files, timeout=_UPLOAD_TIMEOUT_SECONDS)
        print(response.json())
    except (requests.RequestException, ValueError) as exc:
        print(f"Alert image upload failed for alert {alert.pk}: {exc}")


def _record_alerts(camera, camera_id, track_id, labels_list, frame, org_obj, loc_obj, area_obj):
    """Save one Alert per label for a tracked person and group the alerts under the track id."""
    from datetime import datetime

    formatted_current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # The frame does not change while this track's alerts are written, so encode it once.
    _, encoded_frame = cv2.imencode('.jpg', frame, [int(cv2.IMWRITE_JPEG_QUALITY), 80])
    image_bytes = encoded_frame.tobytes()
    incharge_users = list(camera.camera_user.all())

    alert_list = []
    for label in labels_list:
        alert = Alert(
            organization=org_obj,
            camera=camera,
            detection_choice=_detection_choice(label),
            alert_message=f"Alert : {label} detected at {formatted_current_time} in Device ID :{camera_id}",
            frame=ContentFile(image_bytes, name="alert_image.jpg"),
            camera_events=CameraEvent.objects.filter(camera_event=label.split(':')[0]).first(),
        )
        if loc_obj:
            alert.location = loc_obj
        if area_obj:
            alert.area = area_obj
        alert.save()
        alert.camera_incharge.add(*incharge_users)
        # NOTE(review): second save kept in case a post_save handler relies on
        # camera_incharge being populated — confirm before removing.
        alert.save()
        alert_list.append(alert)
        _upload_alert_image(alert, image_bytes)

    alert_person_tracking = AlertPersonTracking.objects.create(tracking_id=int(track_id))
    alert_person_tracking.alerts.add(*alert_list)
    alert_person_tracking.save()


def process_device(camera_id):
    """Run PPE detection on one camera's stream and record alerts until the stream ends.

    Loads the enabled ``Camera`` with id ``camera_id``. Every
    ``_SKIP_FRAMES + 1``-th frame is sent to the YOLO model at
    ``<base_dir>/model/best1.pt``. Further processing happens only on frames
    where the number of detected persons differs from the previous analysed
    frame. On those frames the function:

    - attaches PPE labels to persons and draws them;
    - updates the tracker;
    - for every tracked person with a ``': no'`` label, saves one ``Alert``
      per label and uploads the frame.

    Any exception is printed with its traceback rather than raised, because
    this function runs as the target of a worker process. The video capture
    is always released.
    """
    import traceback

    django.db.close_old_connections()
    cap = None
    try:
        camera = Camera.objects.get(id=camera_id, is_enabled=True)
        print('process', camera.id, '-----------------------------------------running')

        # Materialised once, because the names are reused for every processed frame.
        event_names = LinkCameraEventToCamera.objects.filter(
            camera=camera, is_enabled=True
        ).values_list('camera_events__camera_event', flat=True)
        camera_event_names = [name for name in event_names if name is not None]
        classes = _enabled_classes(camera_event_names)

        # The alert context does not depend on the frame, so look it up once per run.
        org_obj = Organization.objects.all().first()
        loc_obj = Location.objects.filter(loc_name=camera.location.loc_name).first() if camera.location else ''
        area_obj = Area.objects.filter(area_name=camera.area.area_name).first() if camera.area else ''

        # Load the model before opening the stream, so frames do not queue up during loading.
        model = YOLO(str(base_dir) + "/model/best1.pt")
        tracker = SimpleTracker()
        cap = cv2.VideoCapture(camera.rtsp_url)

        frame_skip_counter = 0
        previous_person_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_skip_counter += 1
            if frame_skip_counter <= _SKIP_FRAMES:
                continue
            frame_skip_counter = 0

            detections = _parse_detections(model(frame)[0], classes)
            persons = [
                {"bbox": (x1, y1, x2, y2), "labels": set()}
                for x1, y1, x2, y2, label in detections
                if label == "Person"
            ]
            current_person_count = len(persons)

            # Only frames where the person count changed are analysed further.
            if current_person_count != previous_person_count:
                print(f"Person count changed from {previous_person_count} to {current_person_count}")
                _assign_ppe_to_persons(detections, persons)
                for person in persons:
                    _draw_person(frame, person)

                tracker.update_tracks(persons)
                # Assumes tracker.tracks maps track id -> {"bbox": ..., "labels": ...}
                # — TODO confirm against SimpleTracker.
                for track_id, person in tracker.tracks.items():
                    print(track_id, "------------track--------------")
                    # Judge each track by its own labels only.
                    labels_list = _labels_with_missing_events(person["labels"], camera_event_names)
                    if any(item.partition(':')[2] == " no" for item in labels_list):
                        _record_alerts(camera, camera_id, track_id, labels_list, frame, org_obj, loc_obj, area_obj)

                    x1, y1 = person["bbox"][:2]
                    cv2.putText(frame, f"ID: {track_id}", (int(x1), int(y1) - 10), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
                    check_and_trigger_alerts(track_id, person["labels"])

                if cv2.waitKey(1) == ord('q'):
                    break

            previous_person_count = current_person_count
    except Exception as e:
        print(f"Error processing device {camera_id}: {str(e)}")
        traceback.print_exc()
    finally:
        if cap is not None:
            cap.release()

import time
from multiprocessing import Process
from django.db import close_old_connections

class Command(BaseCommand):
    """Management command that runs one PPE-detection worker process per enabled camera."""

    help = 'Triggers alerts based on video analysis'

    # Seconds each batch of camera worker processes runs before being restarted.
    run_seconds = 300
    # Pause between terminating one batch and starting the next.
    restart_delay_seconds = 2

    def handle(self, *args, **options):
        """Start a worker per enabled camera and restart the whole batch periodically.

        Every ``run_seconds`` the running workers are terminated. A new batch
        is then started from the current set of enabled cameras, so cameras
        that were added, removed, enabled or disabled in the meantime are
        picked up. The loop runs until the command is interrupted.
        """
        from django.db import connections

        while True:
            close_old_connections()
            # (camera_id, Process) pairs, so termination messages name the right camera.
            processes = []
            try:
                camera_ids = list(
                    Camera.objects.filter(is_enabled=True)
                    .order_by('created_at')
                    .values_list('id', flat=True)
                )
                # Forked workers would otherwise inherit, and share, this
                # process's open DB connection; close it so each worker opens its own.
                connections.close_all()

                for camera_id in camera_ids:
                    print(f"devices: {camera_id}")
                    process = Process(target=process_device, args=(camera_id,))
                    processes.append((camera_id, process))
                    process.start()

                print(f"Started {len(processes)} processes.")
                time.sleep(self.run_seconds)

            except Exception as e:
                print(f"Error occurred: {str(e)}")

            finally:
                for camera_id, process in processes:
                    if process.is_alive():
                        process.terminate()
                        process.join()
                        print(f"Terminated process for camera {camera_id}")
                    else:
                        # Reap workers that already exited (e.g. their stream ended).
                        process.join()

                print("Restarting processes...")
                # Brief pause so all workers are fully gone before the next batch starts.
                time.sleep(self.restart_delay_seconds)