import cv2
from myprofile.tracker import *
from myprofile.models import *
from django.core.files.base import ContentFile
import django
from django.core.management.base import BaseCommand
from ultralytics import YOLO
from django.conf import settings
from multiprocessing import Process
import requests
import io

# Project root directory; used below to locate the YOLO weights file.
base_dir = settings.BASE_DIR

def check_and_trigger_alerts(person_id, labels):
    """Print a console alert for *person_id* if any label contains 'no'.

    Labels look like 'Helmet: no' / 'Gloves: yes'; only the 'no' ones
    (missing PPE) trigger the alert message.
    """
    missing = [lbl for lbl in labels if "no" in lbl]
    if not missing:
        return
    alert_message = f"ALERT for ID {person_id}: {' '.join(missing)}"
    print(person_id,"-----------------person_id-----------------")
    print(alert_message)

def process_device(camera_id):
    """Detection/alert worker for a single camera (runs in a child process).

    Streams frames from the camera's RTSP URL, runs the YOLO PPE model on
    every (skip_frames + 1)-th frame, associates PPE detections with person
    bounding boxes, and — whenever the detected person count changes —
    persists ``Alert`` rows and uploads the annotated frame to
    ``settings.URL + '/alert_image_upload'``.

    Args:
        camera_id: primary key of the ``Camera`` row to process.

    The function returns when the stream ends or an error occurs; errors are
    logged to stdout rather than raised (the parent restarts workers
    periodically, see ``Command.handle``).
    """
    # Child processes must not reuse the parent's DB connections.
    django.db.close_old_connections()
    # Resolution used when resizing the annotated frame for display/output.
    output_width, output_height = 1280, 720
    cap = None  # defined up-front so the finally-clause can always release it
    try:
        camera = Camera.objects.get(id=camera_id, is_enabled=True)

        # Event names enabled for this camera, e.g. ['Helmet', 'Gloves'].
        camera_event_names = LinkCameraEventToCamera.objects.filter(
            camera=camera, is_enabled=True
        ).values_list('camera_events__camera_event', flat=True)
        camera_event_names_with_condition = [f"{name}: no" for name in camera_event_names]

        # Classes this camera reacts to: its "<event>: no" labels plus 'Person'.
        data = camera_event_names_with_condition
        data.append('Person')

        # Index order must match the model's class indices exactly.
        rules_classes = [
                'Helmet: yes', 'Jacket: yes', 'Gloves: yes', 'Shoes: yes', 'Person',
                'Helmet: no', 'Jacket: no', 'Gloves: no', 'Shoes: no'
            ]

        data_dict = {item: True for item in data}

        # Preserve the model's index alignment, blanking classes that are not
        # enabled for this camera (blank labels are filtered out downstream).
        classes = [' '] * len(rules_classes)
        for i, item in enumerate(rules_classes):
            if item in data_dict:
                classes[i] = item

        cap = cv2.VideoCapture(camera.rtsp_url)

        model_path = str(base_dir) + "/model/best1.pt"
        model = YOLO(model_path)
        tracker = SimpleTracker()

        # Only every 16th frame is analysed to keep up with the stream.
        skip_frames = 15
        frame_skip_counter = 0
        previous_person_count = 0

        while True:
            ret, frame = cap.read()
            if not ret:
                break  # stream ended or connection dropped
            frame_skip_counter += 1
            if frame_skip_counter <= skip_frames:
                continue
            frame_skip_counter = 0
            results = model(frame)[0]

            persons = []

            # First pass: collect person bounding boxes.
            for result in results:
                x1, y1, x2, y2, conf, class_idx = result.boxes.data.tolist()[0]
                class_idx = int(class_idx)
                class_label = classes[class_idx]

                if class_label == "Person":
                    # "labels" is a set so the same PPE label is never duplicated.
                    person = {"bbox": (x1, y1, x2, y2), "labels": set()}
                    persons.append(person)

            current_person_count = len(persons)
            last_track_id = 0
            if current_person_count != previous_person_count:
                # Person count changed — re-associate PPE labels and alert.
                print(f"Person count changed from {previous_person_count} to {current_person_count}")
                # Second pass: attach each PPE detection to the person whose
                # bbox contains the centre of the PPE bbox.
                for result in results:
                    x1, y1, x2, y2, conf, class_idx = result.boxes.data.tolist()[0]
                    class_idx = int(class_idx)
                    class_label = classes[class_idx]

                    if class_label != "Person":
                        cx, cy = (x1 + x2) / 2, (y1 + y2) / 2  # centre of the PPE bbox
                        for person in persons:
                            px1, py1, px2, py2 = person["bbox"]
                            if px1 <= cx <= px2 and py1 <= cy <= py2:
                                person["labels"].add(class_label)
                                break

                for person in persons:
                    # Draw the person bbox.
                    x1, y1, x2, y2 = person["bbox"]
                    cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 3)

                    # Missing-PPE ("no") labels take precedence over "yes" labels.
                    no_labels = {label for label in person["labels"] if "no" in label}
                    yes_labels = {label for label in person["labels"] if "yes" in label}
                    display_labels = no_labels if no_labels else yes_labels

                    # Initial text position: just above the bbox, stacking upward.
                    text_position = (int(x1), int(y1) - 10)

                    for label in display_labels:
                        # Black background rectangle behind the label text.
                        rectangle_bgr = (0, 0, 0)
                        (text_width, text_height) = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 1.2, thickness=1)[0]
                        box_coords = ((text_position[0], text_position[1]), (text_position[0] + text_width + 2, text_position[1] - text_height - 2))
                        cv2.rectangle(frame, box_coords[0], box_coords[1], rectangle_bgr, cv2.FILLED)

                        # Green for "yes", red for "no".
                        text_color = (0, 255, 0) if label in yes_labels else (0, 0, 255)
                        cv2.putText(frame, label, text_position, cv2.FONT_HERSHEY_SIMPLEX, 1, text_color, 2)
                        # Stack the next label above the previous one.
                        text_position = (text_position[0], text_position[1] - 30)

                # Update tracker with the latest detections.
                tracker.update_tracks(persons)

                for track_id, person in tracker.tracks.items():
                    if last_track_id != track_id:
                        last_track_id = track_id

                        detect = dict(tracker.tracks.items())

                        # NOTE(review): labels_list is overwritten on every
                        # iteration, so only the LAST track's labels reach the
                        # alert loop below — confirm this is intended.
                        labels_list = []
                        for key in detect:
                            labels = detect[key]['labels']
                            labels_list = [label for label in labels if label.strip()]

                        print(track_id, "------------track--------------")
                        x1, y1, x2, y2 = person["bbox"]

                        for lable in labels_list:
                            # 'Helmet: no' -> event name 'Helmet'.
                            label_text = lable.split(':')[0]
                            camera_event_obj = CameraEvent.objects.filter(camera_event=label_text).first()
                            current_time = datetime.now()
                            formatted_current_time = current_time.strftime("%Y-%m-%d %H:%M:%S")
                            _, encoded_frame = cv2.imencode('.jpg', frame, [int(cv2.IMWRITE_JPEG_QUALITY), 80])
                            # NOTE(review): encoded_frame is a numpy array; ContentFile
                            # normally expects bytes — confirm this round-trips, else
                            # pass encoded_frame.tobytes().
                            image_file = ContentFile(encoded_frame, name="alert_image.jpg")
                            org_obj = Organization.objects.all().first()
                            loc_obj = Location.objects.filter(loc_name=camera.location.loc_name).first() if camera.location else ''
                            area_obj = Area.objects.filter(area_name=camera.area.area_name).first() if camera.area else ''
                            alert = Alert(
                                organization=org_obj,
                                camera=camera,
                                alert_message=f"Alert : {lable} detected at {formatted_current_time} in Device ID :{camera_id}",
                                frame=image_file,
                                camera_events=camera_event_obj
                            )
                            if loc_obj:
                                alert.location = loc_obj
                            if area_obj:
                                alert.area = area_obj
                            alert.save()

                            # Attach every in-charge user of this camera to the alert.
                            camera_incharge_list = Camera.objects.get(id=camera_id).camera_user.all()
                            for user in camera_incharge_list:
                                alert.camera_incharge.add(user)
                            alert.save()

                            # Forward the annotated frame to the remote upload endpoint.
                            url = settings.URL + '/alert_image_upload'
                            files = {'image': (alert.frame.name.replace('alert_images/', ''), encoded_frame)}
                            response = requests.post(url, files=files)
                            print(response.json())

                    # Print the track id on the person bbox.
                    cv2.putText(frame, f"ID: {track_id}", (int(x1), int(y1) - 10), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
                    check_and_trigger_alerts(track_id, person["labels"])

                # Resize the annotated frame to the output resolution.
                frame = cv2.resize(frame, (output_width, output_height))

            previous_person_count = current_person_count
    except Exception as e:
        print(f"Error processing device {camera_id}: {str(e)}")
    finally:
        # Release the capture handle so the RTSP connection is not leaked
        # every time the parent restarts the worker fleet.
        if cap is not None:
            cap.release()

import time
from multiprocessing import Process
from django.db import close_old_connections

class Command(BaseCommand):
    """Spawn one detection process per enabled camera, let the fleet run for
    a fixed window, terminate it, and restart — forever.

    Periodic restarts keep long-running RTSP/model workers from accumulating
    stale connections or leaked resources.
    """

    help = 'Triggers alerts based on video analysis'

    # Seconds each batch of camera processes runs before being restarted.
    RUN_WINDOW_SECONDS = 300

    def handle(self, *args, **options):
        while True:  # restart the worker fleet indefinitely
            # Close stale DB connections before forking new processes.
            close_old_connections()
            cameras = Camera.objects.filter(is_enabled=True).order_by('created_at')

            # (process, camera_id) pairs so termination messages name the
            # correct camera (previously the last loop variable was reused,
            # printing the same id for every process).
            processes = []

            try:
                for device in cameras:
                    print(f"devices: {device.id}")
                    # One worker process per enabled camera.
                    process = Process(target=process_device, args=(device.id,))
                    processes.append((process, device.id))
                    process.start()

                print(f"Started {len(processes)} processes.")

                # Let the processes run for the configured window.
                time.sleep(self.RUN_WINDOW_SECONDS)

            except Exception as e:
                print(f"Error occurred: {str(e)}")

            finally:
                # Terminate any workers still alive before restarting.
                for process, cam_id in processes:
                    if process.is_alive():
                        process.terminate()
                        process.join()  # ensure the process has fully exited
                        print(f"Terminated process for camera {cam_id}")

                print("Restarting processes...")

                # Brief pause so all processes are fully gone before restart.
                time.sleep(2)
