"""A harvester for Bader charge analysis."""

import logging
from pathlib import Path
from typing import TypedDict

from pymatgen.io.vasp.outputs import Vasprun

# Logger named after this module; used below for progress and warning messages.
logger: logging.Logger = logging.getLogger(name=__name__)


class BaderAnalysis(TypedDict):
    """Bader change analysis data."""

    min_dist: list[float]
    charge: list[float]
    atomic_volume: list[float]
    vacuum_charge: float
    vacuum_volume: float
    charge_transfer: list[float] | None = None
    nelectrons: float


def parse_acf(src: str | Path) -> list[dict]:
    """Parse Bader output file ACF.dat."""
    with Path(src, "ACF.dat").open(mode="r", encoding="us-ascii") as file:
        lines = file.readlines()

    headers = ("x", "y", "z", "charge", "min_dist", "atomic_volume")
    data: dict[str, list[float]] = {
        "x": [],
        "y": [],
        "z": [],
        "charge": [],
        "min_dist": [],
        "atomic_volume": [],
    }
    # Skip header lines
    lines.pop(0)
    lines.pop(0)

    should_continue = True

    while should_continue:
        line = lines.pop(0).strip()
        if line.startswith("-"):
            should_continue = False
        else:
            vals = map(float, line.split()[1:])
            for k, v in dict(zip(headers, vals, strict=False)).items():
                data[k].append(v)

    for line in lines:
        tokens = line.strip().split(":")
        if tokens[0] == "VACUUM CHARGE":
            data["vacuum_charge"] = float(tokens[1])
        elif tokens[0] == "VACUUM VOLUME":
            data["vacuum_volume"] = float(tokens[1])
        elif tokens[0] == "NUMBER OF ELECTRONS":
            data["nelectrons"] = float(tokens[1])

    return data


def load_bader_results(
    src: str | Path,
) -> BaderAnalysis:
    """Load Bader charge analysis results from a calculation directory.

    Parses ``ACF.dat`` via ``parse_acf`` and reads ``vasprun.xml`` from
    *src*. It then computes per-atom charge transfer as the Bader charge minus
    the valence electron count (``nelectrons``) of the atom's POTCAR.

    Args:
        src: Directory containing ``ACF.dat`` and ``vasprun.xml`` (and a
            POTCAR that pymatgen can locate next to ``vasprun.xml``).

    Returns:
        The parsed Bader data. ``charge_transfer`` is ``None`` when no
        matching POTCAR could be found.

    Raises:
        FileNotFoundError: If ``ACF.dat`` or ``vasprun.xml`` is missing.
        KeyError: If an atom's element has no corresponding POTCAR entry.
        ValueError: If ACF.dat and vasprun.xml list different atom counts.
    """
    logger.info("Loading Bader data for %s", src)
    try:
        data = parse_acf(src=src)
        vasprun = Vasprun(Path(src, "vasprun.xml"))
    except FileNotFoundError:
        logger.warning("Unable to load Bader data for %s", src)
        raise

    potcars = vasprun.get_potcars(True)
    if potcars is None:
        # get_potcars returns None when no POTCAR matching the run is found;
        # report the charge transfer as unavailable rather than crashing.
        logger.warning("No matching POTCAR found for %s; charge transfer unavailable", src)
        data["charge_transfer"] = None
    else:
        # Key by element, not by POTCAR symbol: atomic_symbols holds plain
        # elements ("Fe") while POTCAR symbols may carry suffixes ("Fe_pv").
        elect_by_element = {potcar.element: potcar.nelectrons for potcar in potcars}
        nelects = [elect_by_element[symbol] for symbol in vasprun.atomic_symbols]
        # strict=True surfaces an atom-count mismatch between ACF.dat and
        # vasprun.xml instead of silently truncating.
        data["charge_transfer"] = [
            charge - nelect
            for charge, nelect in zip(data["charge"], nelects, strict=True)
        ]

    logger.info("Successfully loaded Bader data for %s", src)
    return BaderAnalysis(**data)
