import cv2
from datetime import datetime

import requests
from django.conf import settings
from django.core.files.base import ContentFile
from django.core.management.base import BaseCommand
from django.db import close_old_connections
from multiprocessing import Process
from ultralytics import YOLO

from myprofile.models import *
from myprofile.tracker import *

base_dir = settings.BASE_DIR

def check_and_trigger_alerts(person_id, labels):
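    """Print a console alert listing any missing-PPE ("no") labels for a tracked person."""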

    # Any label containing "no" ("Helmet: no", "Shoes: no_detection", ...) means missing PPE.
    no_items = [label for label in labels if "no" in label]
    if no_items:
        alert_message = f"ALERT for ID {person_id}: {' '.join(no_items)}"
        print(alert_message)

def process_device(camera_id):
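    """Read one camera's RTSP stream, run PPE detection, and persist alerts.

    Blocks until the stream ends or fails, so it is intended to run one camera
    per process.
    """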
    close_old_connections()
    # Output resolution for the resized display frame.
    output_width, output_height = 1280, 720
    try:
        camera = Camera.objects.get(id=camera_id, is_enabled=True)
        fps = camera.fps

        # Names of the events (PPE items) enabled for this camera, e.g. "Helmet", "Jacket".
        camera_event_names = LinkCameraEventToCamera.objects.filter(
            camera=camera, is_enabled=True
        ).values_list('camera_events__camera_event', flat=True)

        # Build the label set this camera reacts to: "<item>: yes", "<item>: no", plus "Person".
        event_labels = [f"{name}: yes" for name in camera_event_names]
        event_labels += [f"{name}: no" for name in camera_event_names]
        camera_event_names_with_condition = [f"{name}: no" for name in camera_event_names]

        data = event_labels + ['Person']

        # Class order must match the YOLO model's training order: indices from the
        # model's detections are mapped back to these label strings.
        rules_classes = [
            'Helmet: yes', 'Jacket: yes', 'Gloves: yes', 'Shoes: yes', 'Person',
            'Helmet: no', 'Jacket: no', 'Gloves: no', 'Shoes: no'
        ]

        # Blank out any class this camera is not configured for, so its detections
        # are ignored downstream.
        enabled_labels = set(data)
        classes = [item if item in enabled_labels else ' ' for item in rules_classes]
        
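        # Open the camera's RTSP stream; cap.read() blocks until a frame arrives
        # or the stream drops.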
        cap = cv2.VideoCapture(camera.rtsp_url)


        model_path = str(base_dir)+"/model/best1.pt"
        model = YOLO(model_path)
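        # SimpleTracker is star-imported from myprofile.tracker; assumed to assign
        # persistent IDs by matching person bboxes across frames.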
        tracker = SimpleTracker()
        # Run detection on roughly one of every (skip_frames + 1) frames.
        skip_frames = 5
        frame_skip_counter = 0

        previous_person_count = 0

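        # Main read/detect loop; exits when the stream stops yielding frames.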
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_skip_counter += 1
            if frame_skip_counter <= skip_frames:
                continue
            frame_skip_counter = 0
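            # Run YOLO on this frame; [0] is the Results object for the single image.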
            results = model(frame)[0]

            # First pass: collect every "Person" detection; PPE labels attach to
            # these in the second pass below.
            persons = []
            for result in results:
                x1, y1, x2, y2, conf, class_idx = result.boxes.data.tolist()[0]
                class_idx = int(class_idx)
                class_label = classes[class_idx]

                if class_label == "Person":
                    # Create a new person object with an empty labels set
                    person = {"bbox": (x1, y1, x2, y2), "labels": set()}
                    persons.append(person)

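            # Only re-annotate and re-evaluate alerts when the number of people
            # in frame changes.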
            current_person_count = len(persons)
            last_track_id = None
            if current_person_count != previous_person_count:
                # Person count changed, so redraw labels and re-evaluate alerts for this frame.
                print(f"Person count changed from {previous_person_count} to {current_person_count}")
                for result in results:
                    x1, y1, x2, y2, conf, class_idx = result.boxes.data.tolist()[0]
                    class_idx = int(class_idx)
                    class_label = classes[class_idx]

                    if class_label != "Person":
                        cx, cy = (x1 + x2) / 2, (y1 + y2) / 2  # center of the PPE bbox
                        for person in persons:
                            px1, py1, px2, py2 = person["bbox"]
                            if px1 <= cx <= px2 and py1 <= cy <= py2:  # if the center of the PPE bbox is within the Person bbox
                                person["labels"].add(class_label)  # As "labels" is now a set, the same label cannot be added twice
                                break
                for person in persons:
                    # Draw the bbox
                    x1, y1, x2, y2 = person["bbox"]
                    cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0,255, 0), 3) #Draw the person bbox

                    # Filter labels to display based on the presence of "no" labels
                    no_labels = {label for label in person["labels"] if "no" in label}
                    yes_labels = {label for label in person["labels"] if "yes" in label}

                    # Decide which labels to display
                    display_labels = no_labels | yes_labels
                    
                    # Set the initial position for text (above the bbox)
                    text_position = (int(x1), int(y1) - 10)

                    for label in display_labels:
                        # Black background rectangle sized to the rendered text, for readability.
                        # Measure at the same scale and thickness used by putText below.
                        rectangle_bgr = (0, 0, 0)
                        (text_width, text_height) = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 1, thickness=2)[0]
                        box_coords = ((text_position[0], text_position[1]), (text_position[0] + text_width + 2, text_position[1] - text_height - 2))
                        cv2.rectangle(frame, box_coords[0], box_coords[1], rectangle_bgr, cv2.FILLED)

                        text_color = (0, 255, 0) if label in yes_labels else (0, 0, 255)  # Green for "yes", red for "no"

                        cv2.putText(frame, label, text_position, cv2.FONT_HERSHEY_SIMPLEX, 1, text_color, 2)
                        # Stack subsequent labels upwards above the bbox.
                        text_position = (text_position[0], text_position[1] - 30)

                # Update tracker with the latest detections
                tracker.update_tracks(persons)

                for track_id, person in tracker.tracks.items():
                    # Bbox for this track; also used for the ID overlay below.
                    x1, y1, x2, y2 = person["bbox"]
                    if last_track_id != track_id:
                        last_track_id = track_id

                        # Use this track's own labels, dropping blank placeholders.
                        labels_list = [label for label in person["labels"] if label.strip()]

                        
                        # Pad in a ": no_detection" entry for every configured item
                        # that was not detected at all on this person.
                        for equ in camera_event_names_with_condition:
                            item_name = equ.split(':')[0]
                            if not any(item_name in item for item in labels_list):
                                labels_list.append(item_name + ': no_detection')

                        # Alert only when at least one item is explicitly missing ("<item>: no").
                        is_alert = any(item.split(':')[1] == " no" for item in labels_list)
                        if is_alert:
                            alert_list = []
                            for label in labels_list:
                                label_text = label.split(':')[0]
                                camera_event_obj = CameraEvent.objects.filter(camera_event=label_text).first()
                                current_time = datetime.now()
                                formatted_current_time = current_time.strftime("%Y-%m-%d %H:%M:%S")
                                # Encode the annotated frame as JPEG bytes for storage and upload.
                                _, encoded_frame = cv2.imencode('.jpg', frame, [int(cv2.IMWRITE_JPEG_QUALITY), 80])
                                image_file = ContentFile(encoded_frame.tobytes(), name="alert_image.jpg")
                                if label.split(':')[1] == ' yes':
                                    detect_choice = 'Yes'
                                elif label.split(':')[1] == ' no_detection':
                                    detect_choice = 'NotDetected'
                                else:
                                    detect_choice = 'No'

                                alert = Alert(
                                    camera=camera,
                                    detection_choice=detect_choice,
                                    alert_message=f"Alert : {label} detected at {formatted_current_time} in Device ID :{camera_id}",
                                    frame=image_file,
                                    camera_events=camera_event_obj
                                )
                                alert.save()

                                # Attach every user in charge of this camera; m2m .add()
                                # persists immediately, so no extra save() is needed.
                                for user in camera.camera_user.all():
                                    alert.camera_incharge.add(user)

                                alert_list.append(alert)


                                # Upload the alert frame to the image-upload endpoint.
                                url = settings.URL + '/alert_image_upload'
                                files = {'image': (alert.frame.name.replace('alert_images/', ''), encoded_frame.tobytes())}
                                response = requests.post(url, files=files)
                                print(response.json())

                            # Group this person's alerts under a single tracking record.
                            alert_person_tracking = AlertPersonTracking.objects.create(
                                tracking_id=int(track_id),
                            )
                            for alert in alert_list:
                                alert_person_tracking.alerts.add(alert)

                    # Overlay the tracking ID above the person's bbox.
                    cv2.putText(frame, f"ID: {track_id}", (int(x1), int(y1) - 10), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
                    check_and_trigger_alerts(track_id, person["labels"])

                # Resize the annotated frame to the output resolution.
                frame = cv2.resize(frame, (output_width, output_height))
                # cv2.imshow("yolov8", frame)

                # Note: without an imshow window, waitKey() returns -1 immediately,
                # so this quit check is a no-op when running headless.
                if cv2.waitKey(1) == ord('q'):
                    break

            previous_person_count = current_person_count
    except Exception as e:
        print(f"Error processing device {camera_id}: {e}")
    finally:
        # Release the capture handle if the stream was opened.
        if 'cap' in locals():
            cap.release()
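
# Management command entry point; invoked as `python manage.py <command_name>`,
# where the command name is this module's filename under <app>/management/commands/.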
class Command(BaseCommand):
    help = 'Triggers alerts based on video analysis'
    
    def handle(self, *args, **options):
        cameras = Camera.objects.filter(is_enabled=True).order_by('created_at')

        # Cameras are serviced sequentially; process_device() blocks on its stream,
        # so in practice only the first enabled camera runs in this mode. The
        # commented multiprocessing variant below runs one process per camera.
        for camera in cameras:
            process_device(camera.id)

        # processes = []  # Keep track of the processes so we can wait on them later.

        # for device in cameras:
        #     # Start a new process for each device.
        #     process = Process(target=process_device, args=(device.id,))
        #     processes.append(process)
        #     process.start() 
        
        # print(processes)
        # for process in processes:
        #     process.join()  # Wait for all processes to complete.