# Copyright 2020-2021 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MediaPipe Objectron."""

import enum
from typing import List, Tuple, NamedTuple, Optional

import attr
import numpy as np

# pylint: disable=unused-import
from mediapipe.calculators.core import constant_side_packet_calculator_pb2
from mediapipe.calculators.core import gate_calculator_pb2
from mediapipe.calculators.core import split_vector_calculator_pb2
from mediapipe.calculators.tensor import image_to_tensor_calculator_pb2
from mediapipe.calculators.tensor import inference_calculator_pb2
from mediapipe.calculators.tensor import tensors_to_detections_calculator_pb2
from mediapipe.calculators.tensor import tensors_to_floats_calculator_pb2
from mediapipe.calculators.tensor import tensors_to_landmarks_calculator_pb2
from mediapipe.calculators.tflite import ssd_anchors_calculator_pb2
from mediapipe.calculators.util import association_calculator_pb2
from mediapipe.calculators.util import collection_has_min_size_calculator_pb2
from mediapipe.calculators.util import detection_label_id_to_text_calculator_pb2
from mediapipe.calculators.util import detections_to_rects_calculator_pb2
from mediapipe.calculators.util import landmark_projection_calculator_pb2
from mediapipe.calculators.util import local_file_contents_calculator_pb2
from mediapipe.calculators.util import non_max_suppression_calculator_pb2
from mediapipe.calculators.util import rect_transformation_calculator_pb2
from mediapipe.calculators.util import thresholding_calculator_pb2
from mediapipe.framework.formats import landmark_pb2
from mediapipe.modules.objectron.calculators import annotation_data_pb2
from mediapipe.modules.objectron.calculators import frame_annotation_to_rect_calculator_pb2
from mediapipe.modules.objectron.calculators import lift_2d_frame_annotation_to_3d_calculator_pb2
# pylint: enable=unused-import
from mediapipe.python.solution_base import SolutionBase
from mediapipe.python.solutions import download_utils


class BoxLandmark(enum.IntEnum):
  """The 9 3D box landmarks."""
  #
  #       3 + + + + + + + + 7
  #       +\                +\          UP
  #       + \               + \
  #       +  \              +  \        |
  #       +   4 + + + + + + + + 8       | y
  #       +   +             +   +       |
  #       +   +             +   +       |
  #       +   +     (0)     +   +       .------- x
  #       +   +             +   +        \
  #       1 + + + + + + + + 5   +         \
  #        \  +              \  +          \ z
  #         \ +               \ +           \
  #          \+                \+
  #           2 + + + + + + + + 6
  CENTER = 0
  BACK_BOTTOM_LEFT = 1
  FRONT_BOTTOM_LEFT = 2
  BACK_TOP_LEFT = 3
  FRONT_TOP_LEFT = 4
  BACK_BOTTOM_RIGHT = 5
  FRONT_BOTTOM_RIGHT = 6
  BACK_TOP_RIGHT = 7
  FRONT_TOP_RIGHT = 8

_BINARYPB_FILE_PATH = 'mediapipe/modules/objectron/objectron_cpu.binarypb'
BOX_CONNECTIONS = frozenset([
    (BoxLandmark.BACK_BOTTOM_LEFT, BoxLandmark.FRONT_BOTTOM_LEFT),
    (BoxLandmark.BACK_BOTTOM_LEFT, BoxLandmark.BACK_TOP_LEFT),
    (BoxLandmark.BACK_BOTTOM_LEFT, BoxLandmark.BACK_BOTTOM_RIGHT),
    (BoxLandmark.FRONT_BOTTOM_LEFT, BoxLandmark.FRONT_TOP_LEFT),
    (BoxLandmark.FRONT_BOTTOM_LEFT, BoxLandmark.FRONT_BOTTOM_RIGHT),
    (BoxLandmark.BACK_TOP_LEFT, BoxLandmark.FRONT_TOP_LEFT),
    (BoxLandmark.BACK_TOP_LEFT, BoxLandmark.BACK_TOP_RIGHT),
    (BoxLandmark.FRONT_TOP_LEFT, BoxLandmark.FRONT_TOP_RIGHT),
    (BoxLandmark.BACK_BOTTOM_RIGHT, BoxLandmark.FRONT_BOTTOM_RIGHT),
    (BoxLandmark.BACK_BOTTOM_RIGHT, BoxLandmark.BACK_TOP_RIGHT),
    (BoxLandmark.FRONT_BOTTOM_RIGHT, BoxLandmark.FRONT_TOP_RIGHT),
    (BoxLandmark.BACK_TOP_RIGHT, BoxLandmark.FRONT_TOP_RIGHT),
])


@attr.s(auto_attribs=True)
class ObjectronModel(object):
  model_path: str
  label_name: str


@attr.s(auto_attribs=True, frozen=True)
class ShoeModel(ObjectronModel):
  model_path: str = ('mediapipe/modules/objectron/'
                     'object_detection_3d_sneakers.tflite')
  label_name: str = 'Footwear'


@attr.s(auto_attribs=True, frozen=True)
class ChairModel(ObjectronModel):
  model_path: str = ('mediapipe/modules/objectron/'
                     'object_detection_3d_chair.tflite')
  label_name: str = 'Chair'


@attr.s(auto_attribs=True, frozen=True)
class CameraModel(ObjectronModel):
  model_path: str = ('mediapipe/modules/objectron/'
                     'object_detection_3d_camera.tflite')
  label_name: str = 'Camera'


@attr.s(auto_attribs=True, frozen=True)
class CupModel(ObjectronModel):
  model_path: str = ('mediapipe/modules/objectron/'
                     'object_detection_3d_cup.tflite')
  label_name: str = 'Coffee cup, Mug'

_MODEL_DICT = {
    'Shoe': ShoeModel(),
    'Chair': ChairModel(),
    'Cup': CupModel(),
    'Camera': CameraModel()
}


def _download_oss_objectron_models(objectron_model: str):
  """Downloads the objectron models from the MediaPipe Github repo if they don't exist in the package."""

  download_utils.download_oss_model(
      'mediapipe/modules/objectron/object_detection_ssd_mobilenetv2_oidv4_fp16.tflite'
  )
  download_utils.download_oss_model(objectron_model)


def get_model_by_name(name: str) -> ObjectronModel:
  if name not in _MODEL_DICT:
    raise ValueError(f'{name} is not a valid model name for Objectron.')
  _download_oss_objectron_models(_MODEL_DICT[name].model_path)
  return _MODEL_DICT[name]


@attr.s(auto_attribs=True)
class ObjectronOutputs(object):
  landmarks_2d: landmark_pb2.NormalizedLandmarkList
  landmarks_3d: landmark_pb2.LandmarkList
  rotation: np.ndarray
  translation: np.ndarray
  scale: np.ndarray


class Objectron(SolutionBase):
  """MediaPipe Objectron.

  MediaPipe Objectron processes an RGB image and returns the 3D box landmarks
  and 2D rectangular bounding box of each detected object.
  """

  def __init__(self,
               static_image_mode: bool = False,
               max_num_objects: int = 5,
               min_detection_confidence: float = 0.5,
               min_tracking_confidence: float = 0.99,
               model_name: str = 'Shoe',
               focal_length: Tuple[float, float] = (1.0, 1.0),
               principal_point: Tuple[float, float] = (0.0, 0.0),
               image_size: Optional[Tuple[int, int]] = None,
               ):
    """Initializes a MediaPipe Objectron class.

    Args:
      static_image_mode: Whether to treat the input images as a batch of static
        and possibly unrelated images, or a video stream.
      max_num_objects: Maximum number of objects to detect.
      min_detection_confidence: Minimum confidence value ([0.0, 1.0]) for object
        detection to be considered successful.
      min_tracking_confidence: Minimum confidence value ([0.0, 1.0]) for the
        box landmarks to be considered tracked successfully.
      model_name: Name of model to use for predicting box landmarks, currently
        support {'Shoe', 'Chair', 'Cup', 'Camera'}.
      focal_length: Camera focal length `(fx, fy)`, by default is defined in NDC
        space. To use focal length (fx_pixel, fy_pixel) in pixel space, users
        should provide image_size = (image_width, image_height) to enable
        conversions inside the API.
      principal_point: Camera principal point (px, py), by default is defined in
        NDC space. To use principal point (px_pixel, py_pixel) in pixel space,
        users should provide image_size = (image_width, image_height) to enable
        conversions inside the API.
      image_size (Optional): size (image_width, image_height) of the input image
        , ONLY needed when use focal_length and principal_point in pixel space.

    Raises:
      ConnectionError: If the objectron open source model can't be downloaded
        from the MediaPipe Github repo.
    """
    # Get Camera parameters.
    fx, fy = focal_length
    px, py = principal_point
    if image_size is not None:
      half_width = image_size[0] / 2.0
      half_height = image_size[1] / 2.0
      fx = fx / half_width
      fy = fy / half_height
      px = - (px - half_width) / half_width
      py = - (py - half_height) / half_height

    # Create and init model.
    model = get_model_by_name(model_name)
    super().__init__(
        binary_graph_path=_BINARYPB_FILE_PATH,
        side_inputs={
            'box_landmark_model_path': model.model_path,
            'allowed_labels': model.label_name,
            'max_num_objects': max_num_objects,
            'use_prev_landmarks': not static_image_mode,
        },
        calculator_params={
            ('objectdetectionoidv4subgraph'
             '__TensorsToDetectionsCalculator.min_score_thresh'):
                min_detection_confidence,
            ('boxlandmarksubgraph__ThresholdingCalculator'
             '.threshold'):
                min_tracking_confidence,
            ('Lift2DFrameAnnotationTo3DCalculator'
             '.normalized_focal_x'): fx,
            ('Lift2DFrameAnnotationTo3DCalculator'
             '.normalized_focal_y'): fy,
            ('Lift2DFrameAnnotationTo3DCalculator'
             '.normalized_principal_point_x'): px,
            ('Lift2DFrameAnnotationTo3DCalculator'
             '.normalized_principal_point_y'): py,
        },
        outputs=['detected_objects'])

  def process(self, image: np.ndarray) -> NamedTuple:
    """Processes an RGB image and returns the box landmarks and rectangular bounding box of each detected object.

    Args:
      image: An RGB image represented as a numpy ndarray.

    Raises:
      RuntimeError: If the underlying graph throws any error.
      ValueError: If the input image is not three channel RGB.

    Returns:
      A NamedTuple object with a "detected_objects" field that contains a list
      of detected 3D bounding boxes. Each detected box is represented as an
      "ObjectronOutputs" instance.
    """

    results = super().process(input_data={'image': image})
    if results.detected_objects:  # pytype: disable=attribute-error
      results.detected_objects = self._convert_format(results.detected_objects)  # type: ignore
    else:
      results.detected_objects = None  # pytype: disable=not-writable
    return results

  def _convert_format(
      self,
      inputs: annotation_data_pb2.FrameAnnotation) -> List[ObjectronOutputs]:
    new_outputs = list()
    for annotation in inputs.annotations:
      # Get 3d object pose.
      rotation = np.reshape(np.array(annotation.rotation), (3, 3))
      translation = np.array(annotation.translation)
      scale = np.array(annotation.scale)
      # Get 2d/3d landmakrs.
      landmarks_2d = landmark_pb2.NormalizedLandmarkList()
      landmarks_3d = landmark_pb2.LandmarkList()
      for keypoint in annotation.keypoints:
        point_2d = keypoint.point_2d
        landmarks_2d.landmark.add(x=point_2d.x, y=point_2d.y)
        point_3d = keypoint.point_3d
        landmarks_3d.landmark.add(x=point_3d.x, y=point_3d.y, z=point_3d.z)

      # Add to objectron outputs.
      new_outputs.append(ObjectronOutputs(landmarks_2d, landmarks_3d,
                                          rotation, translation, scale=scale))
    return new_outputs
