# Copyright 2020 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for mediapipe.python.solution_base."""

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np

from google.protobuf import text_format
from mediapipe.framework import calculator_pb2
from mediapipe.framework.formats import detection_pb2
from mediapipe.python import solution_base
from mediapipe.python.solution_base import PacketDataType

CALCULATOR_OPTIONS_TEST_GRAPH_CONFIG = """
  input_stream: 'image_in'
  output_stream: 'image_out'
  node {
    name: 'ImageTransformation'
    calculator: 'ImageTransformationCalculator'
    input_stream: 'IMAGE:image_in'
    output_stream: 'IMAGE:image_out'
    options: {
      [mediapipe.ImageTransformationCalculatorOptions.ext] {
         output_width: 10
         output_height: 10
      }
    }
    node_options: {
      [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] {
         output_width: 10
         output_height: 10
      }
    }
  }
"""


class SolutionBaseTest(parameterized.TestCase):

  def test_invalid_initialization_arguments(self):
    with self.assertRaisesRegex(
        ValueError,
        'Must provide exactly one of \'binary_graph_path\' or \'graph_config\'.'
    ):
      solution_base.SolutionBase()
    with self.assertRaisesRegex(
        ValueError,
        'Must provide exactly one of \'binary_graph_path\' or \'graph_config\'.'
    ):
      solution_base.SolutionBase(
          graph_config=calculator_pb2.CalculatorGraphConfig(),
          binary_graph_path='/tmp/no_such.binarypb')

  @parameterized.named_parameters(('no_graph_input_output_stream', """
      node {
        calculator: 'PassThroughCalculator'
        input_stream: 'in'
        output_stream: 'out'
      }
      """, RuntimeError, 'does not have a corresponding output stream.'),
                                  ('calcualtor_io_mismatch', """
      node {
        calculator: 'PassThroughCalculator'
        input_stream: 'in'
        input_stream: 'in2'
        output_stream: 'out'
      }
      """, ValueError, 'must use matching tags and indexes.'),
                                  ('unkown_registered_stream_type_name', """
      input_stream: 'in'
      output_stream: 'out'
      node {
        calculator: 'PassThroughCalculator'
        input_stream: 'in'
        output_stream: 'out'
      }
      """, RuntimeError, 'Unable to find the type for stream \"in\".'))
  def test_invalid_config(self, text_config, error_type, error_message):
    config_proto = text_format.Parse(text_config,
                                     calculator_pb2.CalculatorGraphConfig())
    with self.assertRaisesRegex(error_type, error_message):
      solution_base.SolutionBase(graph_config=config_proto)

  def test_valid_input_data_type_proto(self):
    text_config = """
      input_stream: 'input_detections'
      output_stream: 'output_detections'
      node {
        calculator: 'DetectionUniqueIdCalculator'
        input_stream: 'DETECTION_LIST:input_detections'
        output_stream: 'DETECTION_LIST:output_detections'
      }
    """
    config_proto = text_format.Parse(text_config,
                                     calculator_pb2.CalculatorGraphConfig())
    with solution_base.SolutionBase(graph_config=config_proto) as solution:
      input_detections = detection_pb2.DetectionList()
      detection_1 = input_detections.detection.add()
      text_format.Parse('score: 0.5', detection_1)
      detection_2 = input_detections.detection.add()
      text_format.Parse('score: 0.8', detection_2)
      results = solution.process({'input_detections': input_detections})
      self.assertTrue(hasattr(results, 'output_detections'))
      self.assertLen(results.output_detections.detection, 2)
      expected_detection_1 = detection_pb2.Detection()
      text_format.Parse('score: 0.5, detection_id: 1', expected_detection_1)
      expected_detection_2 = detection_pb2.Detection()
      text_format.Parse('score: 0.8, detection_id: 2', expected_detection_2)
      self.assertEqual(results.output_detections.detection[0],
                       expected_detection_1)
      self.assertEqual(results.output_detections.detection[1],
                       expected_detection_2)

  def test_invalid_input_data_type_proto_vector(self):
    text_config = """
      input_stream: 'input_detections'
      output_stream: 'output_detections'
      node {
        calculator: 'DetectionUniqueIdCalculator'
        input_stream: 'DETECTIONS:input_detections'
        output_stream: 'DETECTIONS:output_detections'
      }
    """
    config_proto = text_format.Parse(text_config,
                                     calculator_pb2.CalculatorGraphConfig())
    with solution_base.SolutionBase(graph_config=config_proto) as solution:
      detection = detection_pb2.Detection()
      text_format.Parse('score: 0.5', detection)
      with self.assertRaisesRegex(
          NotImplementedError,
          'SolutionBase can only process non-audio and non-proto-list data. '
          + 'PROTO_LIST type is not supported.'
      ):
        solution.process({'input_detections': detection})

  def test_invalid_input_image_data(self):
    text_config = """
      input_stream: 'image_in'
      output_stream: 'image_out'
      node {
        calculator: 'ImageTransformationCalculator'
        input_stream: 'IMAGE:image_in'
        output_stream: 'IMAGE:transformed_image_in'
      }
      node {
        calculator: 'ImageTransformationCalculator'
        input_stream: 'IMAGE:transformed_image_in'
        output_stream: 'IMAGE:image_out'
      }
    """
    config_proto = text_format.Parse(text_config,
                                     calculator_pb2.CalculatorGraphConfig())
    with solution_base.SolutionBase(graph_config=config_proto) as solution:
      with self.assertRaisesRegex(
          ValueError, 'Input image must contain three channel rgb data.'):
        solution.process(np.arange(36, dtype=np.uint8).reshape(3, 3, 4))

  @parameterized.named_parameters(('graph_without_side_packets', """
      input_stream: 'image_in'
      output_stream: 'image_out'
      node {
        calculator: 'ImageTransformationCalculator'
        input_stream: 'IMAGE:image_in'
        output_stream: 'IMAGE:transformed_image_in'
      }
      node {
        calculator: 'ImageTransformationCalculator'
        input_stream: 'IMAGE:transformed_image_in'
        output_stream: 'IMAGE:image_out'
      }
      """, None), ('graph_with_side_packets', """
      input_stream: 'image_in'
      input_side_packet: 'allow_signal'
      input_side_packet: 'rotation_degrees'
      output_stream: 'image_out'
      node {
        calculator: 'ImageTransformationCalculator'
        input_stream: 'IMAGE:image_in'
        input_side_packet: 'ROTATION_DEGREES:rotation_degrees'
        output_stream: 'IMAGE:transformed_image_in'
      }
      node {
        calculator: 'GateCalculator'
        input_stream: 'transformed_image_in'
        input_side_packet: 'ALLOW:allow_signal'
        output_stream: 'image_out_to_transform'
      }
      node {
        calculator: 'ImageTransformationCalculator'
        input_stream: 'IMAGE:image_out_to_transform'
        input_side_packet: 'ROTATION_DEGREES:rotation_degrees'
        output_stream: 'IMAGE:image_out'
      }""", {
          'allow_signal': True,
          'rotation_degrees': 0
      }))
  def test_solution_process(self, text_config, side_inputs):
    self._process_and_verify(
        config_proto=text_format.Parse(text_config,
                                       calculator_pb2.CalculatorGraphConfig()),
        side_inputs=side_inputs)

  def test_invalid_calculator_options(self):
    text_config = """
      input_stream: 'image_in'
      output_stream: 'image_out'
      node {
        calculator: 'ImageTransformationCalculator'
        input_stream: 'IMAGE:image_in'
        output_stream: 'IMAGE:transformed_image_in'
      }
      node {
        name: 'SignalGate'
        calculator: 'GateCalculator'
        input_stream: 'transformed_image_in'
        input_side_packet: 'ALLOW:allow_signal'
        output_stream: 'image_out_to_transform'
      }
      node {
        calculator: 'ImageTransformationCalculator'
        input_stream: 'IMAGE:image_out_to_transform'
        output_stream: 'IMAGE:image_out'
      }
    """
    config_proto = text_format.Parse(text_config,
                                     calculator_pb2.CalculatorGraphConfig())
    with self.assertRaisesRegex(
        ValueError,
        'Modifying the calculator options of SignalGate is not supported.'):
      solution_base.SolutionBase(
          graph_config=config_proto,
          calculator_params={'SignalGate.invalid_field': 'I am invalid'})

  def test_calculator_has_both_options_and_node_options(self):
    config_proto = text_format.Parse(CALCULATOR_OPTIONS_TEST_GRAPH_CONFIG,
                                     calculator_pb2.CalculatorGraphConfig())
    with self.assertRaisesRegex(ValueError,
                                'has both options and node_options fields.'):
      solution_base.SolutionBase(
          graph_config=config_proto,
          calculator_params={
              'ImageTransformation.output_width': 0,
              'ImageTransformation.output_height': 0
          })

  def test_modifying_calculator_proto2_options(self):
    config_proto = text_format.Parse(CALCULATOR_OPTIONS_TEST_GRAPH_CONFIG,
                                     calculator_pb2.CalculatorGraphConfig())
    # To test proto2 options only, remove the proto3 node_options field from the
    # graph config.
    self.assertEqual('ImageTransformation', config_proto.node[0].name)
    config_proto.node[0].ClearField('node_options')
    self._process_and_verify(
        config_proto=config_proto,
        calculator_params={
            'ImageTransformation.output_width': 0,
            'ImageTransformation.output_height': 0
        })

  def test_modifying_calculator_proto3_node_options(self):
    config_proto = text_format.Parse(CALCULATOR_OPTIONS_TEST_GRAPH_CONFIG,
                                     calculator_pb2.CalculatorGraphConfig())
    # To test proto3 node options only, remove the proto2 options field from the
    # graph config.
    self.assertEqual('ImageTransformation', config_proto.node[0].name)
    config_proto.node[0].ClearField('options')
    self._process_and_verify(
        config_proto=config_proto,
        calculator_params={
            'ImageTransformation.output_width': 0,
            'ImageTransformation.output_height': 0
        })

  def test_adding_calculator_options(self):
    config_proto = text_format.Parse(CALCULATOR_OPTIONS_TEST_GRAPH_CONFIG,
                                     calculator_pb2.CalculatorGraphConfig())
    # To test a calculator with no options field, remove both proto2 options and
    # proto3 node_options fields from the graph config.
    self.assertEqual('ImageTransformation', config_proto.node[0].name)
    config_proto.node[0].ClearField('options')
    config_proto.node[0].ClearField('node_options')
    self._process_and_verify(
        config_proto=config_proto,
        calculator_params={
            'ImageTransformation.output_width': 0,
            'ImageTransformation.output_height': 0
        })

  @parameterized.named_parameters(('graph_without_side_packets', """
      input_stream: 'image_in'
      output_stream: 'image_out'
      node {
        calculator: 'ImageTransformationCalculator'
        input_stream: 'IMAGE:image_in'
        output_stream: 'IMAGE:transformed_image_in'
      }
      node {
        calculator: 'ImageTransformationCalculator'
        input_stream: 'IMAGE:transformed_image_in'
        output_stream: 'IMAGE:image_out'
      }
      """, None), ('graph_with_side_packets', """
      input_stream: 'image_in'
      input_side_packet: 'allow_signal'
      input_side_packet: 'rotation_degrees'
      output_stream: 'image_out'
      node {
        calculator: 'ImageTransformationCalculator'
        input_stream: 'IMAGE:image_in'
        input_side_packet: 'ROTATION_DEGREES:rotation_degrees'
        output_stream: 'IMAGE:transformed_image_in'
      }
      node {
        calculator: 'GateCalculator'
        input_stream: 'transformed_image_in'
        input_side_packet: 'ALLOW:allow_signal'
        output_stream: 'image_out_to_transform'
      }
      node {
        calculator: 'ImageTransformationCalculator'
        input_stream: 'IMAGE:image_out_to_transform'
        input_side_packet: 'ROTATION_DEGREES:rotation_degrees'
        output_stream: 'IMAGE:image_out'
      }""", {
          'allow_signal': True,
          'rotation_degrees': 0
      }))
  def test_solution_reset(self, text_config, side_inputs):
    config_proto = text_format.Parse(text_config,
                                     calculator_pb2.CalculatorGraphConfig())
    input_image = np.arange(27, dtype=np.uint8).reshape(3, 3, 3)
    with solution_base.SolutionBase(
        graph_config=config_proto, side_inputs=side_inputs) as solution:
      for _ in range(20):
        outputs = solution.process(input_image)
        self.assertTrue(np.array_equal(input_image, outputs.image_out))
        solution.reset()

  def test_solution_stream_type_hints(self):
    text_config = """
      input_stream: 'union_type_image_in'
      output_stream: 'image_type_out'
      node {
        calculator: 'ToImageCalculator'
        input_stream: 'IMAGE:union_type_image_in'
        output_stream: 'IMAGE:image_type_out'
      }
    """
    config_proto = text_format.Parse(text_config,
                                     calculator_pb2.CalculatorGraphConfig())
    input_image = np.arange(27, dtype=np.uint8).reshape(3, 3, 3)
    with solution_base.SolutionBase(
        graph_config=config_proto,
        stream_type_hints={'union_type_image_in': PacketDataType.IMAGE
                          }) as solution:
      for _ in range(20):
        outputs = solution.process(input_image)
        self.assertTrue(np.array_equal(input_image, outputs.image_type_out))
    with solution_base.SolutionBase(
        graph_config=config_proto,
        stream_type_hints={'union_type_image_in': PacketDataType.IMAGE_FRAME
                          }) as solution2:
      for _ in range(20):
        outputs = solution2.process(input_image)
        self.assertTrue(np.array_equal(input_image, outputs.image_type_out))

  def _process_and_verify(self,
                          config_proto,
                          side_inputs=None,
                          calculator_params=None):
    input_image = np.arange(27, dtype=np.uint8).reshape(3, 3, 3)
    with solution_base.SolutionBase(
        graph_config=config_proto,
        side_inputs=side_inputs,
        calculator_params=calculator_params) as solution:
      outputs = solution.process(input_image)
      outputs2 = solution.process({'image_in': input_image})
    self.assertTrue(np.array_equal(input_image, outputs.image_out))
    self.assertTrue(np.array_equal(input_image, outputs2.image_out))


if __name__ == '__main__':
  absltest.main()
