"""涉及到量子哈密顿量的概念，可以看 [2.8节的笔记](./2-8.ipynb)"""

# AUTOGENERATED! DO NOT EDIT! File to edit: ../../3-8.ipynb.

# %% auto 0
# Public names exported by this module (file header marks it as autogenerated from 3-8.ipynb).
__all__ = ['PolarizationGate', 'ADQCTimeEvolution']

# %% ../../3-8.ipynb 2
import torch
from einops import einsum
from torch import nn

# %% ../../3-8.ipynb 11
from torch import Tensor
from .adqc import ADQCNet
from ..tensor_gates.modules import ParameterizedGate, SimpleGate
from typing import Set, Literal
from ..tensor_gates.functional import spin_operator


class PolarizationGate(ParameterizedGate):
    """
    A single-qubit gate that evolves a qubit under a trainable magnetic field.

    For each direction ``d`` in ``h_directions`` the gate owns one trainable
    scalar field strength ``h_d``. The applied gate is

        U = exp(-1j * time_slice * sum_d h_d * S_d)

    where ``S_d`` is the spin operator ``spin_operator(d.upper())``.
    """

    def __init__(
        self,
        *,
        batched_input: bool,
        time_slice: float,
        target_qubit: int,
        h_directions: Set[Literal["x", "y", "z"]],
    ):
        """
        Initialize the PolarizationGate.

        Args:
            batched_input: Whether the input is batched (forwarded to
                ``ParameterizedGate``).
            time_slice: Evolution time ``t`` in ``exp(-i t H)``; must be > 0.
            target_qubit: Index of the qubit the gate acts on; must be >= 0.
            h_directions: Non-empty set (or frozenset) drawn from
                ``{"x", "y", "z"}`` selecting which spin operators the field
                couples to. One trainable field strength, initialized from
                ``torch.randn(1)``, is created per direction.

        Raises:
            AssertionError: If any of the constraints above is violated.
        """
        assert isinstance(h_directions, (set, frozenset)), "h_directions must be a set"
        assert 3 >= len(h_directions) > 0, "h_directions must be a non-empty set"
        assert all(direction in ("x", "y", "z") for direction in h_directions), (
            "h_directions must contain only x, y, z"
        )
        assert time_slice > 0, "time_slice must be greater than 0"
        assert target_qubit >= 0, "target_qubit must be greater than or equal to 0"
        # Iterate in sorted order: the iteration order of a set of str depends on
        # PYTHONHASHSEED. Without sorting, the order in which the random initial
        # values are drawn would differ between processes even under a fixed
        # torch seed. So would the parameter order seen by optimizers.
        directions = sorted(h_directions)
        parameters = nn.ParameterDict(
            {
                direction: nn.Parameter(torch.randn(1), requires_grad=True)
                for direction in directions
            }
        )

        # NOTE(review): self.gate_params / self.target_qubit / self.apply_gate used
        # in forward() are presumably provided by ParameterizedGate (not visible here).
        super().__init__(
            batched_input=batched_input,
            gate_params=parameters,
            requires_grad=True,
            target_qubit=target_qubit,
        )
        # Fixed spin operators, registered as non-trainable Parameters so they
        # move with the module (.to / .cuda) and appear in state_dict.
        self.spin = nn.ParameterDict(
            {
                "x": nn.Parameter(spin_operator("X"), requires_grad=False),
                "y": nn.Parameter(spin_operator("Y"), requires_grad=False),
                "z": nn.Parameter(spin_operator("Z"), requires_grad=False),
            }
        )
        self.time_slice = time_slice
        # Keep a private copy so later mutation of the caller's set (which may be
        # shared by many gates) cannot silently change this gate.
        self.h_directions = set(h_directions)

    def forward(self, tensor: Tensor) -> Tensor:
        """
        Apply ``exp(-1j * time_slice * sum_d h_d * S_d)`` to ``target_qubit``.

        Args:
            tensor: The state tensor passed through to ``apply_gate``.

        Returns:
            The tensor returned by ``apply_gate`` after applying the gate.
        """
        # Local field Hamiltonian H = sum_d h_d * S_d, summed in a fixed (sorted)
        # order so the floating-point result does not depend on set ordering.
        spin_matrix = sum(
            self.gate_params[direction] * self.spin[direction]
            for direction in sorted(self.h_directions)
        )

        gate = torch.matrix_exp(-1j * self.time_slice * spin_matrix)
        return self.apply_gate(
            tensor=tensor,
            gate=gate,
            target_qubit=self.target_qubit,
        )


class ADQCTimeEvolution(nn.Module):
    """
    Layered time-evolution circuit made of fixed coupling gates and trainable
    single-qubit polarization gates.

    Each of the ``time_steps`` layers consists of:

    1. a non-trainable two-qubit gate ``exp(-1j * time_slice * hamiltonian)``
       (reshaped to 2x2x2x2) on every qubit pair of the "brick" pattern from
       ``ADQCNet.calc_gate_target_qubit_positions``; then
    2. one ``PolarizationGate`` per qubit (indices ``0 .. num_qubits - 1``).

    This is presumably a first-order Trotter-style splitting of
    coupling + local-field evolution (TODO confirm against the notebook).
    All gates are built with ``batched_input=False``.
    """

    def __init__(
        self,
        hamiltonian: torch.Tensor,
        num_qubits: int,
        time_steps: int,
        time_slice: float,
        h_directions: Set[Literal["x", "y", "z"]],
    ):
        """
        Build the circuit.

        Args:
            hamiltonian: Two-qubit coupling Hamiltonian, given either as a 4x4
                matrix or as a 2x2x2x2 tensor.
            num_qubits: Number of qubits; must be > 0.
            time_steps: Number of repeated layers; must be > 0.
            time_slice: Time per layer used in every ``exp(-i t H)``; must be > 0.
            h_directions: Field directions forwarded to every PolarizationGate.

        Raises:
            AssertionError: If any of the constraints above is violated.
        """
        super().__init__()
        assert hamiltonian.shape == (4, 4) or hamiltonian.shape == (2, 2, 2, 2), (
            "Hamiltonian must be a 4x4 matrix or 2x2x2x2 tensor"
        )
        # Bring both accepted layouts to matrix form for matrix_exp
        # (a no-op view when the input is already 4x4).
        h_matrix = hamiltonian.reshape(4, 4)
        assert num_qubits > 0, "Number of qubits must be greater than 0"
        assert time_steps > 0, "Time steps must be greater than 0"
        assert time_slice > 0, "Time slice must be greater than 0"

        coupling_unitary = torch.matrix_exp(-1j * time_slice * h_matrix).reshape(
            2, 2, 2, 2
        )
        pair_positions = ADQCNet.calc_gate_target_qubit_positions(
            gate_pattern="brick", num_qubits=num_qubits
        )

        # Construction order matters: PolarizationGate draws random initial
        # field strengths, so gates are created layer by layer, couplings first.
        layers = []
        for _step in range(time_steps):
            layers.extend(
                SimpleGate(
                    batched_input=False,
                    gate_name="coupling",
                    target_qubit=list(pair),
                    gate=coupling_unitary,
                    requires_grad=False,
                )
                for pair in pair_positions
            )
            layers.extend(
                PolarizationGate(
                    batched_input=False,
                    time_slice=time_slice,
                    target_qubit=qubit,
                    h_directions=h_directions,
                )
                for qubit in range(num_qubits)
            )

        self.net = nn.Sequential(*layers)

    def forward(self, tensor: Tensor) -> Tensor:
        """Apply every gate of the circuit to ``tensor``, in construction order."""
        evolved = self.net(tensor)
        return evolved
