#!/usr/bin/env python3

import argparse
import functools
import json
import pathlib
import resource
import subprocess
import sys
import time
import typing
from contextlib import nullcontext  # type: ignore

from faker import Faker

from streamson import SimpleMatcher, extract_fd

# Read-buffer size in bytes (1024 * 1024); ijson_generic passes it to the
# parse function as the ``buf_size`` keyword argument.
BUFF_SIZE = 1024 * 1024  # use 1MB buffer


def ijson_generic(parse_function: typing.Callable, src_path: str, dst_path: typing.Optional[str] = None) -> int:
    """Count user/group names using an event-based parser.

    Args:
        parse_function: Callable invoked as ``parse_function(file, buf_size=BUFF_SIZE)``
            that yields ``(prefix, event, value)`` triples.
        src_path: Path of the JSON document to read (opened in text mode).
        dst_path: Optional path; when given, every matched name is written
            there on its own line.

    Returns:
        Number of ``"string"`` events whose prefix is ``users.item.name`` or
        ``groups.item.name``.
    """
    wanted_prefixes = ("users.item.name", "groups.item.name")

    # The destination (if any) is opened before the source, as a context
    # manager that yields None when no output was requested.
    sink_manager = pathlib.Path(dst_path).open("w") if dst_path else nullcontext()

    matched = 0
    with sink_manager as sink, pathlib.Path(src_path).open() as source:
        for prefix, event, value in parse_function(source, buf_size=BUFF_SIZE):
            if event != "string" or prefix not in wanted_prefixes:
                continue
            matched += 1
            if sink:
                sink.write(f"{value}\n")
    return matched


def streamson(
    src_path: str,
    dst_path: typing.Optional[str] = None,
) -> int:
    """Count user/group names with streamson's ``extract_fd``.

    Args:
        src_path: Path of the JSON document to read (opened in binary mode).
        dst_path: Optional path; when given, the extracted names are written
            there in binary mode, one per line.

    Returns:
        Number of ``None`` items yielded by ``extract_fd``.
        NOTE(review): presumably one ``None`` is yielded per completed match,
        which would make this the number of matched names - confirm against
        the streamson documentation.
    """
    user_names = SimpleMatcher('{"users"}[]{"name"}')
    group_names = SimpleMatcher('{"groups"}[]{"name"}')
    matcher = user_names | group_names

    # Yields None instead of a file object when no output was requested.
    sink_manager = pathlib.Path(dst_path).open("wb") if dst_path else nullcontext()

    total = 0
    with sink_manager as sink, pathlib.Path(src_path).open("rb") as source:
        # No handler is attached to the matcher (second tuple element is None).
        for item in extract_fd(source, [(matcher, None)], require_path=False):
            if item is None:
                total += 1
                if sink:
                    sink.write(b"\n")
                continue

            if not sink:
                continue

            _path, payload = item
            if payload:
                # Drop the first and last byte: "Name" -> Name
                sink.write(payload[1:-1])

    return total


# Rebind the name to a functools.partial with no pre-bound arguments; calls are
# forwarded to the function above unchanged.
# NOTE(review): presumably kept so this entry has the same shape as the other
# partial-based strategies - confirm, or drop the wrapper.
streamson = functools.partial(streamson)


def std_generic(load_function: typing.Callable, src_path: str, dst_path: typing.Optional[str] = None) -> int:
    """Count user/group names by loading the whole document at once.

    Args:
        load_function: Callable invoked as ``load_function(file)`` that returns
            the parsed document (``json.load``-like).
        src_path: Path of the JSON document to read (opened in text mode).
        dst_path: Optional path; when given, every user name followed by every
            group name is written there, one per line.

    Returns:
        ``len(document["users"]) + len(document["groups"])``.
    """
    sections = ("users", "groups")

    with pathlib.Path(src_path).open() as source:
        document = load_function(source)

    if dst_path:
        with pathlib.Path(dst_path).open("w") as sink:
            for section in sections:
                for entry in document[section]:
                    sink.write(f"{entry['name']}\n")

    return sum(len(document[section]) for section in sections)


# Whole-document strategy backed by the standard library's ``json.load``.
stdlib = functools.partial(std_generic, json.load)


# Registry of benchmarkable strategies keyed by name; each value is called as
# ``strategy(input_path, output_path)`` and returns the number of names found.
STRATEGIES: typing.Dict[str, typing.Callable] = dict(
    streamson=streamson,
    stdlib=stdlib,
)


# hyperjson is optional: it is registered only when the package is importable.
try:
    import hyperjson
except ImportError:
    pass
else:
    STRATEGIES["hyperjson"] = functools.partial(std_generic, hyperjson.load)

# Register one strategy per importable ijson backend.
#
# Each backend is probed on its own: previously all four imports shared a
# single ``try`` block, so one unavailable backend (some of them rely on
# optional native libraries - see the ijson docs) silently removed every ijson
# strategy, including the pure-Python one.
import importlib

for _backend in ("python", "yajl2", "yajl2_c", "yajl2_cffi"):
    try:
        _parse = importlib.import_module(f"ijson.backends.{_backend}").parse
    except (ImportError, AttributeError):
        # ijson itself or just this backend is unavailable: skip this entry only.
        continue
    STRATEGIES[f"ijson-{_backend}"] = functools.partial(ijson_generic, _parse)


def _write_records(f: typing.TextIO, key: str, count: int, faker, closing: str) -> None:
    """Stream ``count`` fake ``{"id", "name"}`` records to ``f`` as JSON array ``key``.

    Records are written one per line as they are produced, so memory use does
    not grow with ``count``. ``closing`` is the text that terminates the array.
    """
    f.write(f'\t"{key}": [\n')
    for i in range(count):
        if i:
            f.write(",\n")
        record = json.dumps({"id": i, "name": faker.name()})
        f.write(f"\t\t{record}")
    f.write(closing)


def generate(output_path: str, user_count: int, group_count: int) -> None:
    """Write a JSON benchmark document with fake users and groups.

    The document has the shape ``{"users": [...], "groups": [...]}`` where each
    element is ``{"id": <index>, "name": <fake name>}``. Faker is seeded with 0
    before the names are drawn.

    Args:
        output_path: Destination file (created or truncated).
        user_count: Number of entries in the ``users`` array.
        group_count: Number of entries in the ``groups`` array.
    """
    Faker.seed(0)
    faker = Faker()
    with pathlib.Path(output_path).open("w") as f:
        f.write("{\n")
        _write_records(f, "users", user_count, faker, "\n\t],\n")
        # The closing bracket now starts on its own line, matching "users";
        # it used to be glued to the last group record.
        _write_records(f, "groups", group_count, faker, "\n\t]\n")
        f.write("}\n")


def memory(strategy: str, input_path: str):
    """Estimate the smallest memory limit (in MB) under which a strategy succeeds.

    The ``_bench`` sub-command of this script is run repeatedly in a child
    process with different ``--memory-limit`` values: first the limit grows in
    fixed steps until a run succeeds, then the interval between the last
    failing and the first passing limit is bisected. Each probe is logged to
    stderr and the final result is printed to stdout as ``<N>MB``.

    NOTE: if the benchmark fails for a reason unrelated to memory (for example
    an unreadable input file), the growth phase never terminates.

    Args:
        strategy: Strategy name forwarded as ``--strategy``.
        input_path: Input file forwarded as ``--input-file``.
    """
    START_STEP = 1000
    # Launch through the running interpreter so the child does not depend on
    # the script's executable bit / shebang and sees the same installed packages.
    args = [sys.executable, __file__, "_bench", "--strategy", strategy, "--input-file", input_path]

    def call(limit) -> int:
        retval = subprocess.call(
            args + ["--memory-limit", str(limit)],
            stderr=subprocess.DEVNULL,
            stdout=subprocess.DEVNULL,
        )
        print(f"{limit}MB - {retval == 0}", file=sys.stderr)
        return retval

    # Never probe 0: `_bench` treats a limit of 0 as "no limit", so it would
    # always pass and be reported as the answer.
    low = 1

    # first find passing max, remembering the largest limit known to fail
    high = START_STEP
    while call(high) != 0:
        low = high + 1
        high += START_STEP

    # cut interval; invariant: `high` passed, every probed limit below `low` failed
    while low < high:
        current = (low + high) // 2
        if call(current) == 0:
            high = current
        else:
            low = current + 1

    # `high` is always a limit that passed, even when every bisection probe failed.
    print(f"{high}MB")


def times(strategy: str, input_path: str, attempts: int):
    """Run the benchmark several times and print the average duration.

    Each attempt runs the ``_bench`` sub-command of this script in a child
    process and parses the duration it prints on stdout. Individual results
    are logged to stderr; the average is printed to stdout.

    Args:
        strategy: Strategy name forwarded as ``--strategy``.
        input_path: Input file forwarded as ``--input-file``.
        attempts: Number of runs; must be at least 1.

    Raises:
        ValueError: If ``attempts`` is less than 1 (an average over zero runs
            is undefined).
        subprocess.CalledProcessError: If a benchmark run exits with a
            non-zero status.
    """
    if attempts < 1:
        raise ValueError(f"attempts must be a positive integer, got {attempts}")

    # Launch through the running interpreter so the child does not depend on
    # the script's executable bit / shebang and sees the same installed packages.
    args = [sys.executable, __file__, "_bench", "--strategy", strategy, "--input-file", input_path]

    def call(attempt_number: int) -> float:
        output = subprocess.check_output(args)

        res = float(output.strip())
        print(f"#{attempt_number}: {res:.5f}s", file=sys.stderr)
        return res

    results = [call(i) for i in range(attempts)]

    print(f"Average of {len(results)} attempts: {sum(results) / len(results):.5f}s")


def bench(
    strategy: str,
    input_path: str,
    output_path: typing.Optional[str],
    memory_limit: int = 0,
    check_total: typing.Optional[int] = None,
):
    """Run one strategy once and print the elapsed time in seconds.

    Args:
        strategy: Key into ``STRATEGIES``.
        input_path: JSON document to process.
        output_path: Optional file the strategy writes the names to.
        memory_limit: When non-zero, ``RLIMIT_DATA`` (soft and hard) is set to
            this many MB before the strategy runs; 0 leaves limits untouched.
        check_total: When not ``None``, the expected number of names; a
            mismatch is reported and the process exits with status 1.
    """
    if memory_limit:
        limit_bytes = memory_limit * 1024 * 1024
        resource.setrlimit(resource.RLIMIT_DATA, (limit_bytes, limit_bytes))

    start_time = time.monotonic()
    total = STRATEGIES[strategy](input_path, output_path)
    total_time = time.monotonic() - start_time

    # Compare against None explicitly: an expected total of 0 is a valid check
    # and used to be skipped by the truthiness test.
    if check_total is not None and total != check_total:
        # Diagnostics go to stderr; stdout carries only the measured time.
        print(f"Failed to check {total} != {check_total}", file=sys.stderr)
        sys.exit(1)

    print(total_time)


def _positive_int(text: str) -> int:
    """argparse ``type`` callable that accepts only integers >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid integer value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {text!r}")
    return value


def main(argv: typing.Optional[typing.List[str]] = None) -> None:
    """Parse the command line and run the selected sub-command.

    Sub-commands: ``generate`` (create input data), ``memory`` (search for the
    smallest passing memory limit), ``time`` (average several runs) and
    ``_bench`` (a single measured run, used by the other two).

    Args:
        argv: Argument list to parse; ``None`` makes argparse use ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(prog="streamson-bench")

    subparsers = parser.add_subparsers(help="Commands", dest="cmd")
    subparsers.required = True

    # Options shared by several sub-commands are declared once.
    def add_input(sub: argparse.ArgumentParser) -> None:
        sub.add_argument("-i", "--input-file", help="input json file", required=True)

    def add_strategy(sub: argparse.ArgumentParser) -> None:
        sub.add_argument("-s", "--strategy", help="parse strategy", choices=list(STRATEGIES), required=True)

    gen_parser = subparsers.add_parser("generate", help="generate input data")
    gen_parser.add_argument("-u", "--users", help="user count", default=1000, type=int)
    gen_parser.add_argument("-g", "--groups", help="group count", default=1000, type=int)
    gen_parser.add_argument("-o", "--output-file", help="output-file", required=True)

    mem_parser = subparsers.add_parser("memory", help="try to figure out the max memory usage")
    add_input(mem_parser)
    add_strategy(mem_parser)

    time_parser = subparsers.add_parser(
        "time",
        help="performs several measurements and calculates the average",
    )
    add_input(time_parser)
    add_strategy(time_parser)
    time_parser.add_argument(
        "-a",
        "--attempts",
        help="number of attempts (default=10)",
        # Reject 0 / negative values here: an average over no runs is undefined.
        type=_positive_int,
        default=10,
    )

    bench_parser = subparsers.add_parser("_bench", help="performs a single benchmark (outputs consumed time)")
    add_strategy(bench_parser)
    add_input(bench_parser)
    bench_parser.add_argument("-o", "--output-file", help="output csv file", default=None)
    bench_parser.add_argument("-l", "--memory-limit", help="memory limit (in MB)", default=0, type=int)
    bench_parser.add_argument(
        "-c",
        "--check-total",
        help="check total group and user count",
        default=None,
        type=int,
    )

    options = parser.parse_args(argv)

    if options.cmd == "generate":
        generate(options.output_file, options.users, options.groups)
    elif options.cmd == "memory":
        memory(options.strategy, options.input_file)
    elif options.cmd == "time":
        times(options.strategy, options.input_file, options.attempts)
    elif options.cmd == "_bench":
        bench(
            options.strategy,
            options.input_file,
            options.output_file,
            options.memory_limit,
            options.check_total,
        )


if __name__ == "__main__":
    main()
