#!/var/lang/bin/python -u
# ruff: noqa
import sys

sys.path.append("/var/runtime")
import json
import os
import typing

import boto3
import urllib3

if typing.TYPE_CHECKING:
    import mypy_boto3_s3

# One shared connection pool for every Extensions API call made by this script.
http = urllib3.PoolManager()

LAMBDA_EXTENSION_NAME = "myext1"
# Read from the environment; None when the variable is not set.
AWS_LAMBDA_RUNTIME_API = os.environ.get("AWS_LAMBDA_RUNTIME_API")


# Pass an explicit endpoint only when AWS_ENDPOINT_URL is present in the
# environment; otherwise create the resource with no extra arguments.
_s3_resource_kwargs = {}
if "AWS_ENDPOINT_URL" in os.environ:
    _s3_resource_kwargs["endpoint_url"] = os.environ["AWS_ENDPOINT_URL"]

s3_resource: "mypy_boto3_s3.S3ServiceResource" = boto3.resource("s3", **_s3_resource_kwargs)


# TEST_BUCKET is mandatory: a missing variable raises KeyError at import time.
bucket_name = os.environ["TEST_BUCKET"]
bucket = s3_resource.Bucket(bucket_name)


def log_event(name: str, event: typing.Union[dict, list]):
    """Store *event* in the module-level bucket as indented JSON.

    Args:
        name: Object key to write under.
        event: JSON-serializable payload; ``json.dumps`` raises ``TypeError``
            if it contains values that cannot be serialized.
    """
    bucket.put_object(Body=json.dumps(event, indent=2), Key=name)


# Number of requests sent so far; used to give each logged call a unique key.
request_counter = 0


def send_and_log_request(method: str, url: str, headers: dict, body: typing.Optional[str] = None):
    """Perform an HTTP request and record the request/response pair via log_event.

    The record is stored under ``extensions_api_call_<n>``, where ``n`` is the
    1-based count of calls made through this function.

    Args:
        method: HTTP method, e.g. ``"GET"`` or ``"POST"``.
        url: Full request URL.
        headers: Request headers.
        body: Optional request body; sent UTF-8 encoded. A falsy body (``None``
            or ``""``) is neither sent nor logged.

    Returns:
        The response object returned by ``http.request``.
    """
    global request_counter
    request_counter += 1

    request_record = {"method": method, "url": url, "headers": headers}
    extra_request_args = {}
    if body:
        extra_request_args["body"] = body.encode("utf-8")
        request_record["body"] = body

    resp = http.request(method, url, headers=headers, **extra_request_args)

    response_record = {
        "headers": dict(resp.headers),
        "body": resp.data.decode(encoding="utf-8"),
        "statuscode": resp.status,
    }
    log_event(
        f"extensions_api_call_{request_counter}",
        {"request": request_record, "response": response_record},
    )
    return resp


# Extension lifecycle: record the environment, register with the Extensions
# API, then poll /event/next until a SHUTDOWN event arrives. Every API call and
# every failure is recorded in the test bucket through log_event.
try:
    log_event("env", list(os.environ.keys()))

    # 1. /register
    response = send_and_log_request(
        "POST",
        f"http://{AWS_LAMBDA_RUNTIME_API}/2020-01-01/extension/register",
        headers={"Lambda-Extension-Name": LAMBDA_EXTENSION_NAME},
        body=json.dumps({"events": ["INVOKE", "SHUTDOWN"]}),
    )
    # The identifier returned by /register must accompany all later calls.
    ext_id = response.headers.get("Lambda-Extension-Identifier")
    next_headers = {"Lambda-Extension-Identifier": ext_id}

    try:
        # 2. /next loop: each call returns the next event for this extension.
        shutdown_received = False
        while not shutdown_received:
            response = send_and_log_request(
                "GET",
                f"http://{AWS_LAMBDA_RUNTIME_API}/2020-01-01/extension/event/next",
                headers=next_headers,
            )
            data = json.loads(response.data.decode("utf-8"))
            shutdown_received = data["eventType"] == "SHUTDOWN"
        log_event("loop_exited", {})
    except Exception as e:
        # TODO: manually trigger the exception here for testing purposes
        # Exception objects are not JSON serializable, so log their repr;
        # passing the object itself would make log_event raise TypeError and
        # skip the /exit/error report below.
        log_event("exception_caught", {"exc": repr(e)})
        exit_content = {
            "errorType": "Extension.ErrorEnum",
            "errorMessage": "something went wrong",
            "stackTrace": [],
        }
        error_headers = {
            **next_headers,
            "Lambda-Extension-Function-Error-Type": "Extension.ErrorEnum",
        }
        # Send error_headers (not next_headers) so the error-type header is
        # actually included in the exit/error report.
        send_and_log_request(
            "POST",
            f"http://{AWS_LAMBDA_RUNTIME_API}/2020-01-01/extension/exit/error",
            headers=error_headers,
            body=json.dumps(exit_content),
        )

except Exception as e:
    # Last-resort record of anything that escaped the handlers above; repr
    # keeps the payload JSON serializable.
    log_event("fatal_exception_caught", {"exc": repr(e)})
