from copy import deepcopy
from pathlib import Path

from ase import Atoms
from ase.build import molecule
from ase.calculators.calculator import Calculator
import ase.io
import pytest

from autojob.tasks.calculation import Calculation
from autojob.tasks.scan import BondScan
from autojob.tasks.scan import BondScanInputs
from autojob.tasks.scan import BondScanParams


@pytest.fixture(name="input_atoms")
def fixture_input_atoms(atoms_filename: str) -> Atoms | None:
    """Provide a CO molecule tagged with structure and filename metadata.

    Args:
        atoms_filename: Value stored under the ``"filename"`` key of the
            molecule's ``info`` dictionary.

    Returns:
        The ``Atoms`` object built by ``molecule("CO")``. Its ``info``
        dictionary holds ``"structure"`` (the characters of the symbols
        string, sorted and re-joined) and ``"filename"``.
    """
    co_molecule = molecule("CO")
    sorted_symbols = sorted(str(co_molecule.symbols))
    co_molecule.info.update(
        structure="".join(sorted_symbols),
        filename=atoms_filename,
    )
    return co_molecule


@pytest.fixture(name="bond_scan_params")
def fixture_bond_scan_params() -> list[BondScanParams]:
    """Provide the bond-scan parameter list used by the scan fixtures.

    Returns:
        A one-element list holding ``BondScanParams(0, 1)``.
        NOTE(review): the two positional arguments are presumably the atom
        indices ``a0`` and ``a1`` -- confirm against ``BondScanParams``.
    """
    first_arg, second_arg = 0, 1
    only_scan = BondScanParams(first_arg, second_arg)
    return [only_scan]


@pytest.fixture(name="traj_template")
def fixture_traj_template() -> str:
    """Provide the filename template for bond-scan trajectory files.

    Returns:
        A ``str.format`` template with two positional ``{}`` placeholders.
    """
    template = "scan_{}_{}.traj"
    return template


@pytest.fixture(name="bond_scan_inputs")
def fixture_bond_scan_inputs(
    traj_template: str,
    bond_scan_params: list[BondScanParams],
) -> BondScanInputs:
    """Bundle the trajectory template and scan parameters into inputs.

    Args:
        traj_template: Filename template forwarded as ``traj_template``.
        bond_scan_params: Scan parameters forwarded as ``bond_scan_params``.

    Returns:
        A ``BondScanInputs`` built from the two arguments; all other fields
        keep whatever defaults ``BondScanInputs`` defines.
    """
    inputs = BondScanInputs(
        traj_template=traj_template,
        bond_scan_params=bond_scan_params,
    )
    return inputs


@pytest.fixture(name="bond_scan")
def fixture_bond_scan(
    calculation: Calculation, bond_scan_inputs: BondScanInputs
) -> BondScan:
    """Build a ``BondScan`` task from an existing calculation.

    Args:
        calculation: Calculation whose ``model_dump()`` output is passed as
            keyword arguments to ``BondScan``.
        bond_scan_inputs: Inputs passed as ``bond_scan_inputs``.

    Returns:
        The new ``BondScan`` with ``task_metadata.task_class`` set to its
        lower-cased class name.
    """
    calculation_fields = calculation.model_dump()
    task: BondScan = BondScan(
        bond_scan_inputs=bond_scan_inputs,
        **calculation_fields,
    )
    # Record the concrete task type in the metadata as a lower-case name.
    task.task_metadata.task_class = task.__class__.__name__.lower()
    return task


@pytest.fixture(name="scan_bond")
def fixture_scan_bond(
    output_atoms: Atoms,
    calculator: Calculator,
    task_directory: Path,
    bond_scan: BondScan,
) -> None:
    """Write one trajectory of scanned images for every configured bond.

    For each entry in ``bond_scan.bond_scan_inputs.bond_scan_params`` the
    distance between atoms ``a0`` and ``a1`` is stepped from the lower to the
    upper value of ``bond_lims`` in increments of ``scan_step``. Every step
    copies ``output_atoms``, sets the distance, runs a deep copy of
    ``calculator`` on the copy, and collects the image. The collected images
    are written with ``ase.io.write`` to a file in ``task_directory`` named by
    formatting ``traj_template`` with ``a0`` and ``a1``.

    Args:
        output_atoms: Structure that every scan image is copied from; its
            ``calc`` attribute is set to ``calculator``.
        calculator: Calculator deep-copied for each image; its ``atoms``
            attribute is set to ``output_atoms``.
        task_directory: Directory that receives the trajectory files.
        bond_scan: Task supplying the template, step size and bond parameters.

    Returns:
        None. The fixture is used for its side effect of writing files.

    NOTE(review): the distance is accumulated by repeated addition, so
    floating-point rounding may decide whether the upper limit itself is
    included, and a non-positive ``scan_step`` would never end the loop.
    """
    # Link the structure and the calculator in both directions.
    output_atoms.calc = calculator
    calculator.atoms = output_atoms

    scan_inputs = bond_scan.bond_scan_inputs
    template = scan_inputs.traj_template
    step = scan_inputs.scan_step

    for params in scan_inputs.bond_scan_params:
        first, second = params.a0, params.a1
        target = Path(task_directory) / template.format(first, second)
        frames = []
        lower, upper = params.bond_lims
        current = lower

        while current <= upper:
            frame = output_atoms.copy()
            frame.set_distance(
                first,
                second,
                distance=current,
                fix=params.fix,
                mask=params.mask,
                indices=params.indices,
            )
            # Assigning Atoms.calc sometimes also sets Calculator.atoms, so
            # each image gets its own calculator copy; this ensures images
            # can be saved together with their calculator results.
            frame.calc = deepcopy(calculator)
            frame.calc.calculate(frame)
            frames.append(frame)

            current += step

        ase.io.write(target, frames)
