"""
Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
"""

import importlib
import json
import logging
import logging.config
import os
import re
import tempfile
import time
import traceback
import unittest
from io import StringIO
from tempfile import NamedTemporaryFile
from unittest.mock import MagicMock, Mock, patch, ANY

import awslambdaric.bootstrap as bootstrap
from awslambdaric.lambda_runtime_exception import FaultException
from awslambdaric.lambda_runtime_log_utils import (
    LogFormat,
    _get_log_level_from_env_var,
    JsonFormatter,
)
from awslambdaric.lambda_runtime_marshaller import LambdaMarshaller
from awslambdaric.lambda_literals import (
    lambda_unhandled_exception_warning_message,
)
import snapshot_restore_py


class TestUpdateXrayEnv(unittest.TestCase):
    """Tests for ``bootstrap.update_xray_env_variable``.

    Each test rebinds ``os.environ`` to a throwaway dict describing the starting
    environment, calls the function under test, then reads ``_X_AMZN_TRACE_ID``
    back through ``os.environ``. ``tearDown`` rebinds the original object.
    """

    def setUp(self):
        self._saved_environ = os.environ

    def tearDown(self):
        os.environ = self._saved_environ

    def _trace_id_after_update(self, starting_env, new_trace_id):
        """Install *starting_env*, apply the update and return the resulting trace id."""
        os.environ = starting_env
        bootstrap.update_xray_env_variable(new_trace_id)
        return os.environ.get("_X_AMZN_TRACE_ID")

    def test_update_xray_env_variable_empty(self):
        self.assertIsNone(self._trace_id_after_update({}, None))

    def test_update_xray_env_variable_remove_old_value(self):
        self.assertIsNone(
            self._trace_id_after_update({"_X_AMZN_TRACE_ID": "old-id"}, None)
        )

    def test_update_xray_env_variable_new_value(self):
        self.assertEqual(self._trace_id_after_update({}, "new-id"), "new-id")

    def test_update_xray_env_variable_overwrite(self):
        self.assertEqual(
            self._trace_id_after_update({"_X_AMZN_TRACE_ID": "old-id"}, "new-id"),
            "new-id",
        )


class TestHandleEventRequest(unittest.TestCase):
    """Tests for ``bootstrap.handle_event_request``.

    ``lambda_runtime`` is a ``Mock`` carrying a real ``LambdaMarshaller``; the
    tests inspect what was handed to its ``post_invocation_result`` /
    ``post_invocation_error`` methods and, for the logging tests, what was
    written to (patched) ``sys.stdout``.
    """

    # Expected StandardLogSink rendering of ``_sample_trace()``: lines joined with
    # "\r", and leading indentation written as NO-BREAK SPACE (U+00A0) rather than
    # SPACE (U+0020). The escapes keep those otherwise invisible characters explicit.
    _SAMPLE_TRACE_LOG = (
        "Traceback (most recent call last):\r"
        '\u00a0\u00a0File "spam.py", line 3, in <module>\r'
        "\u00a0\u00a0\u00a0\u00a0spam.eggs()\r"
        '\u00a0\u00a0File "eggs.py", line 42, in eggs\r'
        '\u00a0\u00a0\u00a0\u00a0return "bacon"\n'
    )

    def setUp(self):
        self.lambda_runtime = Mock()
        self.lambda_runtime.marshaller = LambdaMarshaller()
        self.event_body = '"event_body"'
        self.working_directory = os.getcwd()

        logging.getLogger().handlers.clear()

    def tearDown(self) -> None:
        logging.getLogger().handlers.clear()
        logging.getLogger().level = logging.NOTSET

        return super().tearDown()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def dummy_handler(json_input, lambda_context):
        return {"input": json_input, "aws_request_id": lambda_context.aws_request_id}

    @staticmethod
    def _sample_trace():
        """Return the formatted two-frame stack trace used by the logging tests."""
        return traceback.format_list(
            [
                ("spam.py", 3, "<module>", "spam.eggs()"),
                ("eggs.py", 42, "eggs", 'return "bacon"'),
            ]
        )

    @staticmethod
    def _fault_raising_handler(exception_type, msg, trace):
        """Build a handler raising ``FaultException(exception_type, msg, trace)``.

        The exception is raised while an ``ImportError`` is being handled, so it
        carries an implicit ``__context__`` just like a real import failure.
        """

        def raise_exception_handler(json_input, lambda_context):
            try:
                import invalid_module  # noqa: F401
            except ImportError:
                raise FaultException(exception_type, msg, trace)

        return raise_exception_handler

    def _invoke(
        self,
        request_handler,
        *,
        event_body=None,
        client_context_json=None,
        cognito_identity_json=None,
    ):
        """Run ``handle_event_request`` with the standard test invocation.

        Args:
            request_handler: the handler under test.
            event_body: raw event payload; defaults to ``self.event_body``.
            client_context_json: defaults to ``{}``, the value the happy-path test uses.
            cognito_identity_json: defaults to ``{}``, the value the happy-path test uses.
        """
        bootstrap.handle_event_request(
            self.lambda_runtime,
            request_handler,
            "invoke_id",
            self.event_body if event_body is None else event_body,
            "application/json",
            {} if client_context_json is None else client_context_json,
            {} if cognito_identity_json is None else cognito_identity_json,
            "invoked_function_arn",
            0,
            bootstrap.StandardLogSink(),
        )

    def _posted_error(self):
        """Return ``(error_response, xray_fault_json)`` from ``post_invocation_error``.

        Also asserts that the error was posted for ``"invoke_id"``.
        """
        args, _ = self.lambda_runtime.post_invocation_error.call_args
        self.assertEqual(args[0], "invoke_id")
        return json.loads(args[1]), args[2]

    def _assert_response_contains(self, expected, actual):
        """Assert every key/value pair of *expected* is present in *actual*."""
        self.assertTrue(
            expected.items() <= actual.items(),
            "Expected response is not a subset of the actual response\nExpected: {}\nActual: {}".format(
                expected, actual
            ),
        )

    def _assert_validation_xray_fault(self, xray_fault_json, message):
        """Assert the X-Ray fault is a stackless ``LambdaValidationError`` with *message*."""
        self.assertEqual(
            json.loads(xray_fault_json),
            {
                "working_directory": self.working_directory,
                "exceptions": [
                    {
                        "message": message,
                        "type": "LambdaValidationError",
                        "stack": [],
                    }
                ],
                "paths": [],
            },
        )

    def _assert_user_code_xray_fault(self, xray_fault_json, expected_response):
        """Assert the X-Ray fault has one frame: ``raise_exception_handler`` in this file."""
        this_file = os.path.relpath(__file__)
        xray_fault = json.loads(xray_fault_json)
        self.assertEqual(xray_fault["working_directory"], self.working_directory)
        self.assertEqual(len(xray_fault["exceptions"]), 1)

        exception = xray_fault["exceptions"][0]
        self.assertEqual(exception["message"], expected_response["errorMessage"])
        self.assertEqual(exception["type"], expected_response["errorType"])
        self.assertEqual(len(exception["stack"]), 1)

        frame = exception["stack"][0]
        self.assertEqual(frame["label"], "raise_exception_handler")
        self.assertIsInstance(frame["line"], int)
        self.assertTrue(frame["path"].endswith(this_file))

        self.assertEqual(len(xray_fault["paths"]), 1)
        self.assertTrue(xray_fault["paths"][0].endswith(this_file))

    def _capture_root_logs(self, stream, formatter=None):
        """Route root-logger records to *stream*; ``tearDown`` removes the handler."""
        handler = logging.StreamHandler(stream)
        if formatter is not None:
            handler.setFormatter(formatter)
        logging.getLogger().addHandler(handler)

    # -------------------------------------------------------------------- tests

    def test_handle_event_request_happy_case(self):
        self._invoke(self.dummy_handler)
        self.lambda_runtime.post_invocation_result.assert_called_once_with(
            "invoke_id",
            '{"input": "event_body", "aws_request_id": "invoke_id"}',
            "application/json",
        )

    def test_handle_event_request_invalid_client_context(self):
        expected_response = {
            "errorType": "Runtime.LambdaContextUnmarshalError",
            "errorMessage": "Unable to parse Client Context JSON: Expecting value: line 1 column 1 (char 0)",
        }
        self._invoke(
            self.dummy_handler, client_context_json="invalid_client_context_not_json"
        )
        error_response, xray_fault = self._posted_error()
        self._assert_response_contains(expected_response, error_response)
        self._assert_validation_xray_fault(
            xray_fault, expected_response["errorMessage"]
        )

    def test_handle_event_request_invalid_cognito_idenity(self):
        expected_response = {
            "errorType": "Runtime.LambdaContextUnmarshalError",
            "errorMessage": "Unable to parse Cognito Identity JSON: Expecting value: line 1 column 1 (char 0)",
        }
        self._invoke(self.dummy_handler, cognito_identity_json="invalid_cognito_identity")
        error_response, xray_fault = self._posted_error()
        self._assert_response_contains(expected_response, error_response)
        self._assert_validation_xray_fault(
            xray_fault, expected_response["errorMessage"]
        )

    def test_handle_event_request_invalid_event_body(self):
        expected_response = {
            "errorType": "Runtime.UnmarshalError",
            "errorMessage": "Unable to unmarshal input: Expecting value: line 1 column 1 (char 0)",
        }
        self._invoke(self.dummy_handler, event_body="not_valid_json")
        error_response, xray_fault = self._posted_error()
        self._assert_response_contains(expected_response, error_response)
        self._assert_validation_xray_fault(
            xray_fault, expected_response["errorMessage"]
        )

    def test_handle_event_request_invalid_response(self):
        def invalid_json_response(json_input, lambda_context):
            return type("obj", (object,), {"propertyName": "propertyValue"})

        expected_response = {
            "errorType": "Runtime.MarshalError",
            "errorMessage": "Unable to marshal response: Object of type type is not JSON serializable",
        }
        self._invoke(invalid_json_response)
        error_response, xray_fault = self._posted_error()
        self._assert_response_contains(expected_response, error_response)
        self._assert_validation_xray_fault(
            xray_fault, expected_response["errorMessage"]
        )

    def test_handle_event_request_custom_exception(self):
        def raise_exception_handler(json_input, lambda_context):
            class MyError(Exception):
                def __init__(self, message):
                    self.message = message

            raise MyError("My error")

        expected_response = {"errorType": "MyError", "errorMessage": "My error"}
        self._invoke(raise_exception_handler)
        error_response, xray_fault = self._posted_error()
        self._assert_response_contains(expected_response, error_response)
        self._assert_user_code_xray_fault(xray_fault, expected_response)

    def test_handle_event_request_custom_empty_error_message_exception(self):
        def raise_exception_handler(json_input, lambda_context):
            class MyError(Exception):
                def __init__(self, message):
                    self.message = message

            raise MyError("")

        expected_response = {"errorType": "MyError", "errorMessage": ""}
        self._invoke(raise_exception_handler)
        error_response, xray_fault = self._posted_error()
        self._assert_response_contains(expected_response, error_response)
        self._assert_user_code_xray_fault(xray_fault, expected_response)

    def test_handle_event_request_no_module(self):
        def unable_to_import_module(json_input, lambda_context):
            import invalid_module  # noqa: F401

        expected_response = {
            "errorType": "ModuleNotFoundError",
            "errorMessage": "No module named 'invalid_module'",
        }
        self._invoke(unable_to_import_module)
        error_response, _ = self._posted_error()
        self._assert_response_contains(expected_response, error_response)

    def test_handle_event_request_fault_exception(self):
        expected_response = {
            "errorType": "FaultExceptionType",
            "errorMessage": "Fault exception msg",
            "requestId": "invoke_id",
            "stackTrace": ["trace_line1\ntrace_line2", "trace_line3\ntrace_line4"],
        }
        self._invoke(
            self._fault_raising_handler(
                "FaultExceptionType",
                "Fault exception msg",
                ["trace_line1\ntrace_line2", "trace_line3\ntrace_line4"],
            )
        )
        error_response, xray_fault = self._posted_error()
        # Exact match (not a subset): FaultException controls the whole payload.
        self.assertEqual(error_response, expected_response)
        self._assert_validation_xray_fault(
            xray_fault, expected_response["errorMessage"]
        )

    @patch("sys.stdout", new_callable=StringIO)
    def test_handle_event_request_fault_exception_logging(self, mock_stdout):
        self._capture_root_logs(mock_stdout)
        self._invoke(
            self._fault_raising_handler(
                "FaultExceptionType", "Fault exception msg", self._sample_trace()
            )
        )
        expected_logs = (
            lambda_unhandled_exception_warning_message
            + "\n[ERROR] FaultExceptionType: Fault exception msg\r"
            + self._SAMPLE_TRACE_LOG
        )
        self.assertEqual(mock_stdout.getvalue(), expected_logs)

    @patch("sys.stdout", new_callable=StringIO)
    def test_handle_event_request_fault_exception_logging_notrace(self, mock_stdout):
        self._capture_root_logs(mock_stdout)
        self._invoke(
            self._fault_raising_handler("FaultExceptionType", "Fault exception msg", None)
        )
        expected_logs = (
            lambda_unhandled_exception_warning_message
            + "\n"
            + "[ERROR] FaultExceptionType: Fault exception msg\rTraceback (most recent call last):\n"
        )
        self.assertEqual(mock_stdout.getvalue(), expected_logs)

    @patch("sys.stdout", new_callable=StringIO)
    def test_handle_event_request_fault_exception_logging_nomessage_notrace(
        self, mock_stdout
    ):
        self._capture_root_logs(mock_stdout)
        self._invoke(self._fault_raising_handler("FaultExceptionType", None, None))
        expected_logs = (
            lambda_unhandled_exception_warning_message
            + "\n"
            + "[ERROR] FaultExceptionType\rTraceback (most recent call last):\n"
        )
        self.assertEqual(mock_stdout.getvalue(), expected_logs)

    @patch("sys.stdout", new_callable=StringIO)
    def test_handle_event_request_fault_exception_logging_notype_notrace(
        self, mock_stdout
    ):
        self._capture_root_logs(mock_stdout)
        self._invoke(self._fault_raising_handler(None, "Fault exception msg", None))
        expected_logs = (
            lambda_unhandled_exception_warning_message
            + "\n"
            + "[ERROR] Fault exception msg\rTraceback (most recent call last):\n"
        )
        self.assertEqual(mock_stdout.getvalue(), expected_logs)

    @patch("sys.stdout", new_callable=StringIO)
    def test_handle_event_request_fault_exception_logging_notype_nomessage(
        self, mock_stdout
    ):
        self._capture_root_logs(mock_stdout)
        self._invoke(self._fault_raising_handler(None, None, self._sample_trace()))
        expected_logs = (
            lambda_unhandled_exception_warning_message
            + "\n[ERROR]\r"
            + self._SAMPLE_TRACE_LOG
        )
        self.assertEqual(mock_stdout.getvalue(), expected_logs)

    @patch("sys.stdout", new_callable=StringIO)
    def test_handle_event_request_fault_exception_logging_in_json(self, mock_stdout):
        self._capture_root_logs(mock_stdout, JsonFormatter())
        self._invoke(self._fault_raising_handler("FaultExceptionType", None, None))

        stdout_value = mock_stdout.getvalue()
        received_warning = stdout_value.split("\n")[0]
        received_rest = stdout_value[len(received_warning) + 1 :]

        warning = json.loads(received_warning)
        self.assertEqual(warning["level"], "WARNING")
        self.assertEqual(warning["message"], lambda_unhandled_exception_warning_message)
        self.assertEqual(warning["logger"], "root")
        self.assertIn("timestamp", warning)

        # this line is not in json because of the way the test runtime is bootstrapped
        expected_rest = (
            "\n[ERROR] FaultExceptionType\rTraceback (most recent call last):\n"
        )
        self.assertEqual(received_rest, expected_rest)


class TestXrayFault(unittest.TestCase):
    """Tests for ``bootstrap.make_xray_fault``."""

    def _assert_frame(self, frame, label, path, line):
        """Check one stack entry of an X-Ray exception."""
        self.assertEqual(frame["label"], label)
        self.assertEqual(frame["path"], path)
        self.assertEqual(frame["line"], line)

    def test_make_xray(self):
        class CustomException(Exception):
            def __init__(self):
                pass

        fault = bootstrap.make_xray_fault(
            CustomException.__name__,
            "test_message",
            "working/dir",
            [["test.py", 28, "test_method", "does_not_matter"]],
        )

        self.assertEqual(fault["working_directory"], "working/dir")
        self.assertEqual(fault["paths"], ["test.py"])
        self.assertEqual(len(fault["exceptions"]), 1)

        exception = fault["exceptions"][0]
        self.assertEqual(exception["message"], "test_message")
        self.assertEqual(exception["type"], "CustomException")
        self.assertEqual(len(exception["stack"]), 1)
        self._assert_frame(exception["stack"][0], "test_method", "test.py", 28)

    def test_make_xray_with_multiple_tb(self):
        class CustomException(Exception):
            def __init__(self):
                pass

        fault = bootstrap.make_xray_fault(
            CustomException.__name__,
            "test_message",
            "working/dir",
            [
                ["test.py", 28, "test_method", ""],
                ["another_test.py", 2718, "another_test_method", ""],
            ],
        )

        expected_frames = [
            ("test_method", "test.py", 28),
            ("another_test_method", "another_test.py", 2718),
        ]
        self.assertEqual(len(fault["exceptions"]), 1)
        stack = fault["exceptions"][0]["stack"]
        self.assertEqual(len(stack), len(expected_frames))
        for frame, (label, path, line) in zip(stack, expected_frames):
            self._assert_frame(frame, label, path, line)


class TestGetEventHandler(unittest.TestCase):
    """Tests for ``bootstrap._get_handler`` handler-name resolution and its errors."""

    class FaultExceptionMatcher(BaseException):
        """Equality helper for asserting on a raised ``FaultException``.

        Compares equal to *other* when ``msg`` is a substring of ``other.msg``,
        the exception types are equal, and — if a ``trace_pattern`` was given —
        ``other.trace`` is a list whose concatenation matches the pattern via
        ``re.match`` (anchored at the start).
        """

        def __init__(self, msg, exception_type=None, trace_pattern=None):
            self.msg = msg
            self.exception_type = exception_type
            self.trace = (
                trace_pattern if trace_pattern is None else re.compile(trace_pattern)
            )

        def __eq__(self, other):
            trace_matches = True
            if self.trace is not None:
                # Validate that trace is an array
                if not isinstance(other.trace, list):
                    trace_matches = False
                elif not self.trace.match("".join(other.trace)):
                    trace_matches = False

            return (
                self.msg in other.msg
                and self.exception_type == other.exception_type
                and trace_matches
            )

    def _make_temp_module(self, source):
        """Write *source* (bytes) to a new ``.py`` file in the cwd; return its module name.

        The file is deleted when the test finishes.
        """
        with tempfile.NamedTemporaryFile(
            suffix=".py", dir=".", delete=False
        ) as tmp_file:
            tmp_file.write(source)
        self.addCleanup(os.remove, tmp_file.name)
        # Finders cache directory listings; invalidate only after the file exists
        # so the freshly created module is guaranteed to be importable.
        importlib.invalidate_caches()
        module_name, _ = os.path.splitext(os.path.basename(tmp_file.name))
        return module_name

    def test_get_event_handler_bad_handler(self):
        with self.assertRaises(FaultException) as cm:
            bootstrap._get_handler("bad_handler")
        self.assertEqual(
            self.FaultExceptionMatcher(
                "Bad handler 'bad_handler': not enough values to unpack (expected 2, got 1)",
                "Runtime.MalformedHandlerName",
            ),
            cm.exception,
        )

    def test_get_event_handler_import_error(self):
        with self.assertRaises(FaultException) as cm:
            bootstrap._get_handler("no_module.handler")
        self.assertEqual(
            self.FaultExceptionMatcher(
                "Unable to import module 'no_module': No module named 'no_module'",
                "Runtime.ImportModuleError",
            ),
            cm.exception,
        )

    def test_get_event_handler_syntax_error(self):
        module_name = self._make_temp_module(
            b"def syntax_error()\n\tprint('syntax error, no colon after function')"
        )
        with self.assertRaises(FaultException) as cm:
            bootstrap._get_handler("{}.syntax_error".format(module_name))
        self.assertEqual(
            self.FaultExceptionMatcher(
                "Syntax error in",
                "Runtime.UserCodeSyntaxError",
                ".*File.*\\.py.*Line 1.*",
            ),
            cm.exception,
        )

    def test_get_event_handler_missing_error(self):
        module_name = self._make_temp_module(
            b"def wrong_handler_name():\n\tprint('hello')"
        )
        with self.assertRaises(FaultException) as cm:
            bootstrap._get_handler("{}.my_handler".format(module_name))
        self.assertEqual(
            self.FaultExceptionMatcher(
                "Handler 'my_handler' missing on module '{}'".format(module_name),
                "Runtime.HandlerNotFound",
            ),
            cm.exception,
        )

    def test_get_event_handler_slash(self):
        importlib.invalidate_caches()
        handler_name = "tests/test_handler_with_slash/test_handler.my_handler"
        response_handler = bootstrap._get_handler(handler_name)
        response_handler()

    def test_get_event_handler_build_in_conflict(self):
        with self.assertRaises(FaultException) as cm:
            bootstrap._get_handler("sys.hello")
        self.assertEqual(
            self.FaultExceptionMatcher(
                "Cannot use built-in module sys as a handler module",
                "Runtime.BuiltInModuleConflict",
            ),
            cm.exception,
        )

    def test_get_event_handler_doesnt_throw_build_in_module_name_slash(self):
        response_handler = bootstrap._get_handler(
            "tests/test_built_in_module_name/sys.my_handler"
        )
        response_handler()

    def test_get_event_handler_doent_throw_build_in_module_name(self):
        response_handler = bootstrap._get_handler(
            "tests.test_built_in_module_name.sys.my_handler"
        )
        response_handler()


class TestContentType(unittest.TestCase):
    """Content-type handling of ``bootstrap.handle_event_request``."""

    def setUp(self):
        self.lambda_runtime = Mock()
        self.lambda_runtime.marshaller = LambdaMarshaller()

    def _handle(self, request_handler, event_body, content_type):
        """Drive one invocation with fixed ids/deadline and a fresh StandardLogSink."""
        bootstrap.handle_event_request(
            lambda_runtime_client=self.lambda_runtime,
            request_handler=request_handler,
            invoke_id="invoke-id",
            event_body=event_body,
            content_type=content_type,
            client_context_json=None,
            cognito_identity_json=None,
            invoked_function_arn="invocation-arn",
            epoch_deadline_time_in_ms=1415836801003,
            log_sink=bootstrap.StandardLogSink(),
        )

    def test_application_json(self):
        self._handle(
            lambda event, ctx: {"response": event["msg"]},
            b'{"msg":"foo"}',
            "application/json",
        )

        self.lambda_runtime.post_invocation_result.assert_called_once_with(
            "invoke-id", '{"response": "foo"}', "application/json"
        )

    def test_binary_request_binary_response(self):
        png_bytes = b"\x89PNG\r\n\x1a\n\x00\x00\x00"
        self._handle(lambda event, ctx: event, png_bytes, "image/png")

        self.lambda_runtime.post_invocation_result.assert_called_once_with(
            "invoke-id", png_bytes, "application/unknown"
        )

    def test_json_request_binary_response(self):
        png_bytes = b"\x89PNG\r\n\x1a\n\x00\x00\x00"
        self._handle(
            lambda event, ctx: png_bytes,
            b'{"msg":"ignored"}',
            "application/json",
        )

        self.lambda_runtime.post_invocation_result.assert_called_once_with(
            "invoke-id", png_bytes, "application/unknown"
        )

    def test_binary_with_application_json(self):
        # Binary bytes declared as JSON: no result is posted and an
        # UnmarshalError is reported instead.
        self._handle(
            lambda event, ctx: event,
            b"\x89PNG\r\n\x1a\n\x00\x00\x00",
            "application/json",
        )

        self.lambda_runtime.post_invocation_result.assert_not_called()
        self.lambda_runtime.post_invocation_error.assert_called_once()

        positional, _kwargs = self.lambda_runtime.post_invocation_error.call_args
        invoke_id, error_result, _xray_fault = positional
        error_dict = json.loads(error_result)

        self.assertEqual("invoke-id", invoke_id)
        self.assertEqual("Runtime.UnmarshalError", error_dict["errorType"])


class TestLogError(unittest.TestCase):
    """Tests for ``bootstrap.log_error`` with the standard and framed sinks."""

    def _log_error_to_framed_sink(self, path, *make_error_args):
        """Log one ``make_error(*make_error_args)`` record to a framed sink on *path*.

        Returns the ``(before, after)`` microsecond timestamps that bracket the
        write, for checking the frame's timestamp field.
        """
        before = int(time.time_ns() / 1000)
        with bootstrap.FramedTelemetryLogSink(
            os.open(path, os.O_CREAT | os.O_RDWR)
        ) as log_sink:
            err_to_log = bootstrap.make_error(*make_error_args)
            bootstrap.log_error(err_to_log, log_sink)
        after = int(time.time_ns() / 1000)
        return before, after

    def _assert_single_error_frame(self, path, expected_message, before, after):
        """Assert *path* holds exactly one error frame carrying *expected_message*.

        Layout checked: 4-byte big-endian frame type, 4-byte big-endian payload
        length, 8-byte big-endian timestamp in microseconds, then the payload.
        """
        with open(path, "rb") as f:
            content = f.read()

        frame_type = int.from_bytes(content[:4], "big")
        self.assertEqual(frame_type, 0xA55A0017)

        # The length field counts encoded bytes, not characters.
        length = int.from_bytes(content[4:8], "big")
        self.assertEqual(length, len(expected_message.encode("utf8")))

        timestamp = int.from_bytes(content[8:16], "big")
        self.assertLessEqual(before, timestamp)
        self.assertLessEqual(timestamp, after)

        # Decoding everything after the header also proves no trailing data.
        self.assertEqual(content[16:].decode(), expected_message)

    @patch("sys.stdout", new_callable=StringIO)
    def test_log_error_standard_log_sink(self, mock_stdout):
        err_to_log = bootstrap.make_error("Error message", "ErrorType", None)
        bootstrap.log_error(err_to_log, bootstrap.StandardLogSink())

        expected_logged_error = (
            "[ERROR] ErrorType: Error message\rTraceback (most recent call last):\n"
        )
        self.assertEqual(mock_stdout.getvalue(), expected_logged_error)

    def test_log_error_framed_log_sink(self):
        with NamedTemporaryFile() as temp_file:
            before, after = self._log_error_to_framed_sink(
                temp_file.name, "Error message", "ErrorType", None
            )

            expected_logged_error = (
                "[ERROR] ErrorType: Error message\nTraceback (most recent call last):"
            )
            self._assert_single_error_frame(
                temp_file.name, expected_logged_error, before, after
            )

    @patch("sys.stdout", new_callable=StringIO)
    def test_log_error_indentation_standard_log_sink(self, mock_stdout):
        err_to_log = bootstrap.make_error(
            "Error message", "ErrorType", ["  line1  ", "  line2  ", "  "]
        )
        bootstrap.log_error(err_to_log, bootstrap.StandardLogSink())

        expected_logged_error = (
            "[ERROR] ErrorType: Error message\rTraceback (most recent call last):"
            "\r\xa0\xa0line1  \r\xa0\xa0line2  \r\xa0\xa0\n"
        )
        self.assertEqual(mock_stdout.getvalue(), expected_logged_error)

    def test_log_error_indentation_framed_log_sink(self):
        with NamedTemporaryFile() as temp_file:
            before, after = self._log_error_to_framed_sink(
                temp_file.name,
                "Error message",
                "ErrorType",
                ["  line1  ", "  line2  ", "  "],
            )

            expected_logged_error = (
                "[ERROR] ErrorType: Error message\nTraceback (most recent call last):"
                "\n\xa0\xa0line1  \n\xa0\xa0line2  \n\xa0\xa0"
            )
            self._assert_single_error_frame(
                temp_file.name, expected_logged_error, before, after
            )

    @patch("sys.stdout", new_callable=StringIO)
    def test_log_error_empty_stacktrace_line_standard_log_sink(self, mock_stdout):
        err_to_log = bootstrap.make_error(
            "Error message", "ErrorType", ["line1", "", "line2"]
        )
        bootstrap.log_error(err_to_log, bootstrap.StandardLogSink())

        expected_logged_error = "[ERROR] ErrorType: Error message\rTraceback (most recent call last):\rline1\r\rline2\n"
        self.assertEqual(mock_stdout.getvalue(), expected_logged_error)

    def test_log_error_empty_stacktrace_line_framed_log_sink(self):
        with NamedTemporaryFile() as temp_file:
            before, after = self._log_error_to_framed_sink(
                temp_file.name, "Error message", "ErrorType", ["line1", "", "line2"]
            )

            expected_logged_error = (
                "[ERROR] ErrorType: Error message\nTraceback "
                "(most recent call last):\nline1\n\nline2"
            )
            self._assert_single_error_frame(
                temp_file.name, expected_logged_error, before, after
            )

    def test_log_error_invokeId_line_framed_log_sink(self):
        # The requestId passed to make_error must not appear in the logged text.
        with NamedTemporaryFile() as temp_file:
            before, after = self._log_error_to_framed_sink(
                temp_file.name,
                "Error message",
                "ErrorType",
                ["line1", "", "line2"],
                "testrequestId",
            )

            expected_logged_error = (
                "[ERROR] ErrorType: Error message\nTraceback "
                "(most recent call last):\nline1\n\nline2"
            )
            self._assert_single_error_frame(
                temp_file.name, expected_logged_error, before, after
            )


class TestUnbuffered(unittest.TestCase):
    """``bootstrap.Unbuffered`` forwards each write and flushes after it."""

    def test_write(self):
        stream = MagicMock()
        wrapper = bootstrap.Unbuffered(stream)

        wrapper.write("YOLO!")

        stream.write.assert_called_once_with("YOLO!")
        stream.flush.assert_called_once()

    def test_writelines(self):
        stream = MagicMock()
        wrapper = bootstrap.Unbuffered(stream)

        lines = ["YOLO!"]
        wrapper.writelines(lines)

        stream.writelines.assert_called_once_with(["YOLO!"])
        stream.flush.assert_called_once()


class TestLogSink(unittest.TestCase):
    """Tests for log-sink selection and the framed telemetry wire format."""

    @patch("sys.stdout", new_callable=StringIO)
    def test_create_unbuffered_log_sinks(self, mock_stdout):
        # patch.dict snapshots os.environ and restores it on exit, so removing
        # the variable here cannot leak into other tests.
        with patch.dict(os.environ):
            os.environ.pop("_LAMBDA_TELEMETRY_LOG_FD", None)

            actual = bootstrap.create_log_sink()

            self.assertIsInstance(actual, bootstrap.StandardLogSink)
            actual.log("log")
            self.assertEqual(mock_stdout.getvalue(), "log")

    def test_create_framed_telemetry_log_sinks(self):
        fd = 3
        # Restored on exit even if create_log_sink raises.
        with patch.dict(os.environ, {"_LAMBDA_TELEMETRY_LOG_FD": str(fd)}):
            actual = bootstrap.create_log_sink()

            self.assertIsInstance(actual, bootstrap.FramedTelemetryLogSink)
            self.assertEqual(actual.fd, fd)
            # create_log_sink is expected to consume the variable.
            self.assertNotIn("_LAMBDA_TELEMETRY_LOG_FD", os.environ)

    def test_single_frame(self):
        with NamedTemporaryFile() as temp_file:
            message = "hello world\nsomething on a new line!\n"
            before = int(time.time_ns() / 1000)
            with bootstrap.FramedTelemetryLogSink(
                os.open(temp_file.name, os.O_CREAT | os.O_RDWR)
            ) as ls:
                ls.log(message)
            after = int(time.time_ns() / 1000)

            with open(temp_file.name, "rb") as f:
                content = f.read()

            # Frame lengths are byte counts of the encoded payload.
            payload = message.encode("utf8")

            frame_type = int.from_bytes(content[:4], "big")
            self.assertEqual(frame_type, 0xA55A0003)

            length = int.from_bytes(content[4:8], "big")
            self.assertEqual(length, len(payload))

            timestamp = int.from_bytes(content[8:16], "big")
            self.assertLessEqual(before, timestamp)
            self.assertLessEqual(timestamp, after)

            self.assertEqual(content[16:].decode(), message)

    def test_multiple_frame(self):
        with NamedTemporaryFile() as temp_file:
            first_message = "hello world\nsomething on a new line!"
            second_message = "hello again\nhere's another message\n"

            before = int(time.time_ns() / 1000)
            with bootstrap.FramedTelemetryLogSink(
                os.open(temp_file.name, os.O_CREAT | os.O_RDWR)
            ) as ls:
                ls.log(first_message)
                ls.log(second_message)
            after = int(time.time_ns() / 1000)

            with open(temp_file.name, "rb") as f:
                content = f.read()

            pos = 0
            for message in [first_message, second_message]:
                # Offsets advance by encoded byte length, not character count.
                payload = message.encode("utf8")

                frame_type = int.from_bytes(content[pos : pos + 4], "big")
                self.assertEqual(frame_type, 0xA55A0003)
                pos += 4

                length = int.from_bytes(content[pos : pos + 4], "big")
                self.assertEqual(length, len(payload))
                pos += 4

                timestamp = int.from_bytes(content[pos : pos + 8], "big")
                self.assertLessEqual(before, timestamp)
                self.assertLessEqual(timestamp, after)
                pos += 8

                self.assertEqual(content[pos : pos + len(payload)], payload)
                pos += len(payload)

            # Nothing may follow the last frame.
            self.assertEqual(content[pos:], b"")


class TestLoggingSetup(unittest.TestCase):
    """Checks the root logger level chosen by ``bootstrap._setup_logging``."""

    def setUp(self):
        # _setup_logging mutates the process-wide root logger; put its
        # handlers and level back afterwards so nothing leaks into other tests.
        root = logging.getLogger()
        saved_handlers = list(root.handlers)
        saved_level = root.level

        def _restore_root_logger():
            root.handlers[:] = saved_handlers
            root.setLevel(saved_level)

        self.addCleanup(_restore_root_logger)

    def test_log_level(self) -> None:
        test_cases = [
            (LogFormat.JSON, "TRACE", logging.DEBUG),
            (LogFormat.JSON, "DEBUG", logging.DEBUG),
            (LogFormat.JSON, "INFO", logging.INFO),
            (LogFormat.JSON, "WARN", logging.WARNING),
            (LogFormat.JSON, "ERROR", logging.ERROR),
            (LogFormat.JSON, "FATAL", logging.CRITICAL),
            (LogFormat.TEXT, "TRACE", logging.DEBUG),
            (LogFormat.TEXT, "DEBUG", logging.DEBUG),
            (LogFormat.TEXT, "INFO", logging.INFO),
            (LogFormat.TEXT, "WARN", logging.WARN),
            (LogFormat.TEXT, "ERROR", logging.ERROR),
            (LogFormat.TEXT, "FATAL", logging.CRITICAL),
            ("Unknown format", "INFO", logging.INFO),
            # if level is unknown fall back to default
            (LogFormat.JSON, "Unknown level", logging.NOTSET),
        ]
        for fmt, log_level, expected_level in test_cases:
            # Label each case so a failure report names the one that broke.
            with self.subTest(fmt=fmt, log_level=log_level):
                # Drop previous setup
                logging.getLogger().handlers.clear()
                logging.getLogger().level = logging.NOTSET

                bootstrap._setup_logging(
                    fmt,
                    _get_log_level_from_env_var(log_level),
                    bootstrap.StandardLogSink(),
                )

                self.assertEqual(expected_level, logging.getLogger().level)


class TestLambdaLoggerHandlerSetup(unittest.TestCase):
    """Checks the frame header written by the root handler for each env config."""

    @classmethod
    def tearDownClass(cls):
        # The cases below reload bootstrap under a patched environment; reload
        # once more so its module-level settings match the real environment.
        importlib.reload(bootstrap)
        logging.getLogger().handlers.clear()
        logging.getLogger().level = logging.NOTSET

    def test_handler_setup(self, *_):
        # (total bytes written, frame type, payload length, environment)
        test_cases = [
            (62, 0xA55A0003, 46, {}),
            (133, 0xA55A001A, 117, {"AWS_LAMBDA_LOG_FORMAT": "JSON"}),
            (62, 0xA55A001B, 46, {"AWS_LAMBDA_LOG_LEVEL": "INFO"}),
        ]

        for total_length, header, message_length, env_vars in test_cases:
            with self.subTest(env_vars=env_vars):
                with patch.dict(
                    os.environ, env_vars, clear=True
                ), NamedTemporaryFile() as temp_file:
                    # Re-read module-level settings from the patched environment.
                    importlib.reload(bootstrap)
                    logging.getLogger().handlers.clear()
                    logging.getLogger().level = logging.NOTSET

                    before = int(time.time_ns() / 1000)
                    with bootstrap.FramedTelemetryLogSink(
                        os.open(temp_file.name, os.O_CREAT | os.O_RDWR)
                    ) as ls:
                        bootstrap._setup_logging(
                            bootstrap._AWS_LAMBDA_LOG_FORMAT,
                            bootstrap._AWS_LAMBDA_LOG_LEVEL,
                            ls,
                        )
                        logger = logging.getLogger()
                        logger.critical("critical")
                    after = int(time.time_ns() / 1000)

                    # Use a context manager so the read handle is closed.
                    with open(temp_file.name, "rb") as f:
                        content = f.read()
                    self.assertEqual(len(content), total_length)

                    pos = 0
                    frame_type = int.from_bytes(content[pos : pos + 4], "big")
                    self.assertEqual(frame_type, header)
                    pos += 4

                    length = int.from_bytes(content[pos : pos + 4], "big")
                    self.assertEqual(length, message_length)
                    pos += 4

                    timestamp = int.from_bytes(content[pos : pos + 8], "big")
                    self.assertLessEqual(before, timestamp)
                    self.assertLessEqual(timestamp, after)


class TestLogging(unittest.TestCase):
    """End-to-end checks of JSON logging configured by ``bootstrap._setup_logging``."""

    def setUp(self) -> None:
        # Configure per test rather than once per class. Some tests below
        # change the root level or replace its handlers, and a shared setup
        # would make the others depend on execution order.
        root = logging.getLogger()
        saved_handlers = list(root.handlers)
        saved_level = root.level

        def _restore_root_logger():
            root.handlers[:] = saved_handlers
            root.setLevel(saved_level)

        self.addCleanup(_restore_root_logger)

        root.handlers.clear()
        root.level = logging.NOTSET
        bootstrap._setup_logging(
            LogFormat.from_str("JSON"), "INFO", bootstrap.StandardLogSink()
        )

    @patch("sys.stderr", new_callable=StringIO)
    def test_json_formatter(self, mock_stderr):
        logger = logging.getLogger("a.b")

        test_cases = [
            (
                logging.ERROR,
                "TEST 1",
                {
                    "level": "ERROR",
                    "logger": "a.b",
                    "message": "TEST 1",
                    "requestId": "",
                },
            ),
            (
                logging.ERROR,
                "test \nwith \nnew \nlines",
                {
                    "level": "ERROR",
                    "logger": "a.b",
                    "message": "test \nwith \nnew \nlines",
                    "requestId": "",
                },
            ),
            (
                logging.CRITICAL,
                "TEST CRITICAL",
                {
                    "level": "CRITICAL",
                    "logger": "a.b",
                    "message": "TEST CRITICAL",
                    "requestId": "",
                },
            ),
        ]
        for level, msg, expected in test_cases:
            with self.subTest(msg):
                with patch("sys.stdout", new_callable=StringIO) as mock_stdout:
                    logger.log(level, msg)

                    data = json.loads(mock_stdout.getvalue())
                    # The timestamp varies between runs; drop it before comparing.
                    data.pop("timestamp")
                    self.assertEqual(
                        data,
                        expected,
                    )
        self.assertEqual(mock_stderr.getvalue(), "")

    @patch("sys.stdout", new_callable=StringIO)
    @patch("sys.stderr", new_callable=StringIO)
    def test_exception(self, mock_stderr, mock_stdout):
        try:
            raise ValueError("error message")
        except ValueError:
            logging.getLogger("test.logger").exception("test exception")

        exception_log = json.loads(mock_stdout.getvalue())
        self.assertIn("location", exception_log)
        self.assertIn("stackTrace", exception_log)
        exception_log.pop("timestamp")
        exception_log.pop("location")
        stack_trace = exception_log.pop("stackTrace")

        self.assertEqual(len(stack_trace), 1)

        self.assertEqual(
            exception_log,
            {
                "errorMessage": "error message",
                "errorType": "ValueError",
                "level": "ERROR",
                "logger": "test.logger",
                "message": "test exception",
                "requestId": "",
            },
        )

        self.assertEqual(mock_stderr.getvalue(), "")

    @patch("sys.stdout", new_callable=StringIO)
    @patch("sys.stderr", new_callable=StringIO)
    def test_log_level(self, mock_stderr, mock_stdout):
        logger = logging.getLogger("test.logger")

        # Root is configured at INFO, so only the info record is emitted.
        logger.debug("debug message")
        logger.info("info message")

        data = json.loads(mock_stdout.getvalue())
        data.pop("timestamp")

        self.assertEqual(
            data,
            {
                "level": "INFO",
                "logger": "test.logger",
                "message": "info message",
                "requestId": "",
            },
        )
        self.assertEqual(mock_stderr.getvalue(), "")

    @patch("sys.stdout", new_callable=StringIO)
    @patch("sys.stderr", new_callable=StringIO)
    def test_set_log_level_manually(self, mock_stderr, mock_stdout):
        logger = logging.getLogger("test.logger")

        # Changing log level after `bootstrap._setup_logging`
        logging.getLogger().setLevel(logging.CRITICAL)

        logger.debug("debug message")
        logger.info("info message")
        logger.warning("warning message")
        logger.error("error message")
        logger.critical("critical message")

        data = json.loads(mock_stdout.getvalue())
        data.pop("timestamp")

        self.assertEqual(
            data,
            {
                "level": "CRITICAL",
                "logger": "test.logger",
                "message": "critical message",
                "requestId": "",
            },
        )
        self.assertEqual(mock_stderr.getvalue(), "")

    @patch("sys.stdout", new_callable=StringIO)
    @patch("sys.stderr", new_callable=StringIO)
    def test_set_log_level_with_dictConfig(self, mock_stderr, mock_stdout):
        # Reconfiguring logging after `bootstrap._setup_logging`. The
        # StreamHandler defaults to sys.stderr (mocked here), so output is
        # expected there rather than on stdout.
        logging.config.dictConfig(
            {
                "version": 1,
                "disable_existing_loggers": False,
                "formatters": {"simple": {"format": "%(levelname)-8s - %(message)s"}},
                "handlers": {
                    "stdout": {
                        "class": "logging.StreamHandler",
                        "formatter": "simple",
                    },
                },
                "root": {
                    "level": "CRITICAL",
                    "handlers": [
                        "stdout",
                    ],
                },
            }
        )

        logger = logging.getLogger("test.logger")
        logger.debug("debug message")
        logger.info("info message")
        logger.warning("warning message")
        logger.error("error message")
        logger.critical("critical message")

        data = mock_stderr.getvalue()
        self.assertEqual(
            data,
            "CRITICAL - critical message\n",
        )
        self.assertEqual(mock_stdout.getvalue(), "")


class TestBootstrapModule(unittest.TestCase):
    """Tests for the top-level ``bootstrap.run`` entry point."""

    _APP_ROOT = "/tmp/test/app_root"
    _HANDLER = "app.my_test_handler"
    _RUNTIME_API_ADDR = "test_addr"

    @patch("awslambdaric.bootstrap.LambdaRuntimeClient")
    def test_run(self, mock_runtime_client):
        first_request = MagicMock()
        first_request.x_amzn_trace_id = "123"

        runtime = mock_runtime_client.return_value
        # Two invocations are queued. A further wait_next_invocation call
        # raises StopIteration because the side_effect list is exhausted.
        runtime.wait_next_invocation.side_effect = [first_request, MagicMock()]

        with self.assertRaises(SystemExit) as exit_ctx:
            bootstrap.run(self._APP_ROOT, self._HANDLER, self._RUNTIME_API_ADDR)

        self.assertEqual(exit_ctx.exception.code, 1)

    @patch(
        "awslambdaric.bootstrap.LambdaLoggerHandler",
        Mock(side_effect=Exception("Boom!")),
    )
    @patch("awslambdaric.bootstrap.build_fault_result")
    @patch("awslambdaric.bootstrap.log_error", MagicMock())
    @patch("awslambdaric.bootstrap.LambdaRuntimeClient", MagicMock())
    @patch("awslambdaric.bootstrap.sys")
    def test_run_exception(self, mock_sys, mock_build_fault_result):
        # The patched sys.exit raises, so run() stops at the exit call and the
        # test can inspect the status it was given.
        class ExitCalled(Exception):
            pass

        mock_build_fault_result.return_value = {}
        mock_sys.exit.side_effect = ExitCalled("Boom!")

        with self.assertRaises(ExitCalled):
            bootstrap.run(self._APP_ROOT, self._HANDLER, self._RUNTIME_API_ADDR)

        mock_sys.exit.assert_called_once_with(1)


class TestOnInitComplete(unittest.TestCase):
    """Failures in snapshot/restore hooks are reported and end the process."""

    # stackTrace is matched with ANY because its exact content varies. The
    # assertions pin the propagation of errorType and errorMessage.
    error_result = {
        "errorMessage": "This is a Dummy type error",
        "errorType": "TypeError",
        "requestId": "",
        "stackTrace": ANY,
    }

    def tearDown(self):
        # Reset the private hook registries so hooks registered by one test
        # do not run in the next.
        snapshot_restore_py._before_snapshot_registry = []
        snapshot_restore_py._after_restore_registry = []

    def raise_type_error(self):
        raise TypeError("This is a Dummy type error")

    @patch("awslambdaric.bootstrap.LambdaRuntimeClient")
    def test_before_snapshot_exception(self, mock_runtime_client):
        snapshot_restore_py.register_before_snapshot(self.raise_type_error)

        with self.assertRaises(SystemExit) as exit_ctx:
            bootstrap.on_init_complete(
                mock_runtime_client, log_sink=bootstrap.StandardLogSink()
            )

        self.assertEqual(exit_ctx.exception.code, 64)
        mock_runtime_client.post_init_error.assert_called_once_with(
            self.error_result,
            FaultException.BEFORE_SNAPSHOT_ERROR,
        )

    @patch("awslambdaric.bootstrap.LambdaRuntimeClient")
    def test_after_restore_exception(self, mock_runtime_client):
        snapshot_restore_py.register_after_restore(self.raise_type_error)

        with self.assertRaises(SystemExit) as exit_ctx:
            bootstrap.on_init_complete(
                mock_runtime_client, log_sink=bootstrap.StandardLogSink()
            )

        self.assertEqual(exit_ctx.exception.code, 65)
        mock_runtime_client.restore_next.assert_called_once()
        mock_runtime_client.report_restore_error.assert_called_once_with(
            self.error_result
        )


if __name__ == "__main__":
    # Allow running this test module directly with the Python interpreter.
    unittest.main()