"""
OatDB Python Client - Improved Version

This client provides a fluent API for building and executing OatDB queries.
"""

import requests
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional, Union
from hashlib import sha256
from pldag import PLDAG, CompilationSetting
from graphlib import TopologicalSorter
from abc import ABC, abstractmethod
import json

def from_sub_data(sub_data: Dict[str, dict]) -> PLDAG:
    """Build a compiled PLDAG from the node mapping of an OatDB ``sub`` result.

    Every value in ``sub_data`` is read as a tagged dict whose first entry
    names the node kind:

    * ``"Primitive"``: the payload is indexed as ``[0]`` / ``[1]`` and turned
      into a complex bound (real part, imaginary part).
    * ``"Composite"``: the payload provides ``coefficients`` (converted with
      ``dict``) and ``bias`` (only element ``[0]`` is used).

    Entries with any other tag are ignored.

    Args:
        sub_data: Mapping of node id to tagged node description.

    Returns:
        The PLDAG after all nodes were added and ``compile()`` was called.
    """
    dag = PLDAG(compilation_setting=CompilationSetting.ON_DEMAND)

    # Step 1: flatten the tagged representation into uniform records.
    structured = {}
    for node_id, node_data in sub_data.items():
        tag, payload = list(node_data.items())[0]
        if tag == "Primitive":
            structured[node_id] = {
                "type": "primitive",
                "bound": complex(payload[0], payload[1]),
            }
        elif tag == "Composite":
            structured[node_id] = {
                "type": "composite",
                "coefficients": dict(payload['coefficients']),
                "bias": payload['bias'][0],
            }

    # Step 2: record, per node, which nodes must exist before it. Primitives
    # have no predecessors; composites depend on every id they reference.
    dependencies = {}
    for node_id, record in structured.items():
        predecessors = set()
        if record['type'] == "composite":
            for dep in record['coefficients'].keys():
                predecessors.add(dep)
        dependencies[node_id] = predecessors

    # Step 3: insert nodes so that referenced nodes come before their users.
    for node_id in TopologicalSorter(dependencies).static_order():
        record = structured[node_id]
        kind = record['type']
        if kind == "primitive":
            dag.set_primitive(id=node_id, bound=record['bound'])
            continue
        if kind == "composite":
            # NOTE(review): node_id is not passed to set_gelineq; presumably
            # PLDAG assigns composite ids itself - confirm.
            dag.set_gelineq(
                coefficients=record['coefficients'],
                bias=record['bias'],
            )

    dag.compile()
    return dag


# Type alias for an argument that is either a literal value or a FunctionCall
# whose output is referenced instead. 'FunctionCall' is written as a string
# (forward reference) because the class is only defined further down.
RefOrValue = Union['FunctionCall', str, int, List, Dict]


@dataclass
class FunctionCall(ABC):
    """Base class for all OatDB function calls.

    A call is identified by ``out``, a SHA-256 digest of its function name
    and arguments. Arguments may contain other FunctionCall objects; those
    contribute their own ``out`` to the digest and are serialized as
    ``{"$ref": <out>}`` by the ``_ref_or_value*`` helpers.
    """
    # Name of the OatDB function to invoke.
    fn: str
    # Unserialized arguments; values may themselves be FunctionCall objects.
    args: dict

    def _serialize_for_hash(self, obj, _memo: Optional[Dict[int, str]] = None):
        """Recursively convert ``obj`` into a JSON-serializable structure for hashing.

        FunctionCalls become ``"<ref:<out>>"`` strings, dict values and
        list/tuple items are converted recursively, complex numbers become
        ``[real, imag]`` and everything else is returned unchanged.

        Args:
            obj: Value to convert.
            _memo: Internal cache (``id(call) -> out``) shared across one
                traversal so a sub-call referenced several times is hashed
                only once. Callers normally omit it.
        """
        if _memo is None:
            _memo = {}
        if isinstance(obj, FunctionCall):
            return f"<ref:{obj._compute_out(_memo)}>"
        elif isinstance(obj, dict):
            return {k: self._serialize_for_hash(v, _memo) for k, v in obj.items()}
        elif isinstance(obj, (list, tuple)):
            return [self._serialize_for_hash(v, _memo) for v in obj]
        elif isinstance(obj, complex):
            return [obj.real, obj.imag]
        else:
            return obj

    def _compute_out(self, memo: Dict[int, str]) -> str:
        """Return this call's digest, reusing digests already stored in ``memo``.

        Without the shared memo every reference to a nested call would redo
        that call's whole hash, which is exponential for call graphs whose
        sub-calls are shared. Keys are ``id()`` values; the objects stay alive
        for the whole traversal because they are reachable from ``args``.
        """
        key = id(self)
        digest = memo.get(key)
        if digest is None:
            serialized = self._serialize_for_hash(
                {"fn": self.fn, "args": self.args}, memo
            )
            # sort_keys makes the digest independent of dict insertion order.
            hash_data = json.dumps(serialized, sort_keys=True)
            digest = sha256(hash_data.encode()).hexdigest()
            memo[key] = digest
        return digest

    @property
    def out(self) -> str:
        """Generate a unique output identifier for this call.

        The value is recomputed on each access (``args`` is mutable), using a
        fresh memo that lives only for this computation.
        """
        return self._compute_out({})

    @abstractmethod
    def to_json(self) -> dict:
        """Convert this call to JSON format for the OatDB API."""
        pass

    def _ref_or_value(self, value: "RefOrValue") -> Union[dict, str, int, List]:
        """Return ``{"$ref": value.out}`` for a FunctionCall, else ``value`` unchanged."""
        if isinstance(value, FunctionCall):
            return {"$ref": value.out}
        return value

    def _ref_or_value_list(self, values: "List[RefOrValue]") -> List[Union[dict, str]]:
        """Apply ``_ref_or_value`` to every element of ``values``."""
        return [self._ref_or_value(v) for v in values]


@dataclass
class SetPrimitiveCall(FunctionCall):
    """Serializer for a ``set_primitive`` call.

    Reads ``args['id']`` (literal or FunctionCall) and ``args['bound']``, a
    complex number emitted as ``[int(real), int(imag)]``.
    """

    def to_json(self) -> dict:
        """Return the ``fn`` / ``args`` / ``out`` payload for this call."""
        node_id = self._ref_or_value(self.args['id'])
        bound = self.args['bound']
        call_args = {
            "id": node_id,
            "bound": [int(bound.real), int(bound.imag)],
        }
        return dict(fn=self.fn, args=call_args, out=self.out)


@dataclass
class SetGelineqCall(FunctionCall):
    """Serializer for a ``set_gelineq`` call.

    Reads ``args['coefficients']`` (a dict of id -> coefficient, or a
    FunctionCall standing for the whole list), ``args['bias']`` and an
    optional ``args['alias']``.
    """

    def to_json(self) -> dict:
        """Return the ``fn`` / ``args`` / ``out`` payload for this call.

        ``alias`` is emitted only when it is present and not ``None``, the
        same way the logical and binary operator calls emit theirs.
        """
        coefficients = self.args['coefficients']

        # Either reference the output of another call as a whole, or expand
        # the mapping into a list of {id, coefficient} records.
        if isinstance(coefficients, FunctionCall):
            coeff_value = {"$ref": coefficients.out}
        else:
            coeff_value = [
                {
                    "id": self._ref_or_value(k),
                    "coefficient": v
                }
                for k, v in coefficients.items()
            ]

        args = {
            "coefficients": coeff_value,
            "bias": self._ref_or_value(self.args['bias'])
        }

        # The alias is part of args (and of the ``out`` hash); it was
        # previously dropped here.
        alias = self.args.get('alias')
        if alias is not None:
            args['alias'] = self._ref_or_value(alias)

        return {
            "fn": self.fn,
            "args": args,
            "out": self.out
        }


@dataclass
class SetThresholdOperatorCall(FunctionCall):
    """Serializer shared by ``set_atleast``, ``set_atmost`` and ``set_equal``.

    Reads ``args['references']`` (list of ids / FunctionCalls, or one
    FunctionCall standing for the whole list), ``args['value']`` and an
    optional ``args['alias']``.
    """

    def to_json(self) -> dict:
        """Return the ``fn`` / ``args`` / ``out`` payload for this call.

        ``alias`` is emitted only when it is present and not ``None``, the
        same way the logical and binary operator calls emit theirs.
        """
        references = self.args['references']

        # Handle whole list reference or list of references
        if isinstance(references, FunctionCall):
            ref_value = {"$ref": references.out}
        else:
            ref_value = self._ref_or_value_list(references)

        args = {
            "references": ref_value,
            "value": self._ref_or_value(self.args['value'])
        }

        # The alias is stored in args by the client; it was previously
        # dropped here.
        alias = self.args.get('alias')
        if alias is not None:
            args['alias'] = self._ref_or_value(alias)

        return {
            "fn": self.fn,
            "args": args,
            "out": self.out
        }


@dataclass
class SetLogicalOperatorCall(FunctionCall):
    """Serializer shared by ``set_and``, ``set_or``, ``set_not`` and ``set_xor``.

    Reads ``args['references']`` and, when the key exists, ``args['alias']``.
    """

    def to_json(self) -> dict:
        """Return the ``fn`` / ``args`` / ``out`` payload for this call."""
        refs = self.args['references']
        call_args = {
            # A FunctionCall stands for the entire list; otherwise each
            # element is converted on its own.
            "references": (
                self._ref_or_value(refs)
                if isinstance(refs, FunctionCall)
                else self._ref_or_value_list(refs)
            )
        }

        if 'alias' in self.args:
            call_args['alias'] = self._ref_or_value(self.args['alias'])

        return dict(fn=self.fn, args=call_args, out=self.out)


@dataclass
class SetBinaryOperatorCall(FunctionCall):
    """Serializer shared by ``set_imply`` and ``set_equiv``.

    Reads ``args['lhs']``, ``args['rhs']`` and, when the key exists,
    ``args['alias']``.
    """

    def to_json(self) -> dict:
        """Return the ``fn`` / ``args`` / ``out`` payload for this call."""
        call_args = {}
        for side in ("lhs", "rhs"):
            call_args[side] = self._ref_or_value(self.args[side])

        if 'alias' in self.args:
            call_args['alias'] = self._ref_or_value(self.args['alias'])

        return dict(fn=self.fn, args=call_args, out=self.out)


@dataclass
class SubCall(FunctionCall):
    """Serializer for a ``sub`` call; reads ``args['root']``."""

    def to_json(self) -> dict:
        """Return the ``fn`` / ``args`` / ``out`` payload for this call."""
        call_args = {"root": self._ref_or_value(self.args['root'])}
        return dict(fn=self.fn, args=call_args, out=self.out)


@dataclass
class SetPropertyCall(FunctionCall):
    """Serializer for a ``set_property`` call.

    ``args['id']`` may be a FunctionCall reference; ``args['property']`` and
    ``args['value']`` are passed through unchanged.
    """

    def to_json(self) -> dict:
        """Return the ``fn`` / ``args`` / ``out`` payload for this call."""
        source = self.args
        call_args = {
            "id": self._ref_or_value(source['id']),
            "property": source['property'],
            "value": source['value'],
        }
        return dict(fn=self.fn, args=call_args, out=self.out)


@dataclass
class SetPrimitivesCall(FunctionCall):
    """Serializer for a ``set_primitives`` call.

    Reads ``args['ids']`` (list, or one FunctionCall standing for the whole
    list) and ``args['bound']``, a complex number emitted as
    ``[int(real), int(imag)]``.
    """

    def to_json(self) -> dict:
        """Return the ``fn`` / ``args`` / ``out`` payload for this call."""
        ids = self.args['ids']
        ids_value = (
            self._ref_or_value(ids)
            if isinstance(ids, FunctionCall)
            else self._ref_or_value_list(ids)
        )
        bound = self.args['bound']
        call_args = {
            "ids": ids_value,
            "bound": [int(bound.real), int(bound.imag)],
        }
        return dict(fn=self.fn, args=call_args, out=self.out)


@dataclass
class GetNodeIdsCall(FunctionCall):
    """Serializer for a ``get_node_ids`` call.

    NOTE(review): only ``args['filter']`` is serialized; any other entry in
    ``args`` (for example a ``dag`` entry) is not sent - confirm against the
    server API.
    """

    def to_json(self) -> dict:
        """Return the ``fn`` / ``args`` / ``out`` payload for this call."""
        call_args = {}
        # .get() treats a missing key and an explicit None the same way.
        node_filter = self.args.get('filter')
        if node_filter is not None:
            call_args['filter'] = self._ref_or_value(node_filter)
        return dict(fn=self.fn, args=call_args, out=self.out)


@dataclass
class GetNodesCall(FunctionCall):
    """Serializer for a ``get_nodes`` call; reads ``args['ids']``."""

    def to_json(self) -> dict:
        """Return the ``fn`` / ``args`` / ``out`` payload for this call."""
        ids = self.args['ids']
        # One FunctionCall replaces the whole list; otherwise convert per item.
        ids_value = (
            self._ref_or_value(ids)
            if isinstance(ids, FunctionCall)
            else self._ref_or_value_list(ids)
        )
        return dict(fn=self.fn, args={"ids": ids_value}, out=self.out)


@dataclass
class GetNodeCall(FunctionCall):
    """Serializer for a ``get_node`` call; reads ``args['id']``."""

    def to_json(self) -> dict:
        """Return the ``fn`` / ``args`` / ``out`` payload for this call."""
        call_args = {"id": self._ref_or_value(self.args['id'])}
        return dict(fn=self.fn, args=call_args, out=self.out)


@dataclass
class GetPropertyValuesCall(FunctionCall):
    """Serializer for a ``get_property_values`` call.

    ``args['property']`` is passed through unchanged.
    """

    def to_json(self) -> dict:
        """Return the ``fn`` / ``args`` / ``out`` payload for this call."""
        call_args = {"property": self.args['property']}
        return dict(fn=self.fn, args=call_args, out=self.out)


@dataclass
class GetAliasCall(FunctionCall):
    """Serializer for a ``get_alias`` call; reads ``args['id']``."""

    def to_json(self) -> dict:
        """Return the ``fn`` / ``args`` / ``out`` payload for this call."""
        call_args = {"id": self._ref_or_value(self.args['id'])}
        return dict(fn=self.fn, args=call_args, out=self.out)


@dataclass
class GetIdFromAliasCall(FunctionCall):
    """Serializer for a ``get_id_from_alias`` call; reads ``args['alias']``."""

    def to_json(self) -> dict:
        """Return the ``fn`` / ``args`` / ``out`` payload for this call."""
        call_args = {"alias": self._ref_or_value(self.args['alias'])}
        return dict(fn=self.fn, args=call_args, out=self.out)


@dataclass
class GetAliasesFromIdCall(FunctionCall):
    """Serializer for a ``get_aliases_from_id`` call; reads ``args['id']``."""

    def to_json(self) -> dict:
        """Return the ``fn`` / ``args`` / ``out`` payload for this call."""
        call_args = {"id": self._ref_or_value(self.args['id'])}
        return dict(fn=self.fn, args=call_args, out=self.out)


@dataclass
class GetIdsFromAliasesCall(FunctionCall):
    """Serializer for a ``get_ids_from_aliases`` call; reads ``args['aliases']``."""

    def to_json(self) -> dict:
        """Return the ``fn`` / ``args`` / ``out`` payload for this call."""
        aliases = self.args['aliases']
        # One FunctionCall replaces the whole list; otherwise convert per item.
        aliases_value = (
            self._ref_or_value(aliases)
            if isinstance(aliases, FunctionCall)
            else self._ref_or_value_list(aliases)
        )
        return dict(fn=self.fn, args={"aliases": aliases_value}, out=self.out)


@dataclass
class PropagateCall(FunctionCall):
    """Serializer for a ``propagate`` call.

    ``args['assignments']`` is either a FunctionCall (sent as a reference)
    or an iterable of dicts with ``id`` and ``bound`` entries; ``bound`` is
    passed through unchanged.
    """

    def to_json(self) -> dict:
        """Return the ``fn`` / ``args`` / ``out`` payload for this call."""
        assignments = self.args['assignments']
        if isinstance(assignments, FunctionCall):
            assign_value = self._ref_or_value(assignments)
        else:
            assign_value = []
            for entry in assignments:
                assign_value.append({
                    "id": self._ref_or_value(entry['id']),
                    "bound": entry['bound'],
                })
        return dict(fn=self.fn, args={"assignments": assign_value}, out=self.out)


@dataclass
class PropagateManyCall(FunctionCall):
    """Serializer for a ``propagate_many`` call.

    ``args['many_assignments']`` is either a FunctionCall (sent as a
    reference) or an iterable of assignment lists, each holding dicts with
    ``id`` and ``bound`` entries.
    """

    def to_json(self) -> dict:
        """Return the ``fn`` / ``args`` / ``out`` payload for this call."""
        many = self.args['many_assignments']
        if isinstance(many, FunctionCall):
            many_value = self._ref_or_value(many)
        else:
            many_value = []
            for group in many:
                converted = []
                for entry in group:
                    converted.append({
                        "id": self._ref_or_value(entry['id']),
                        "bound": entry['bound'],
                    })
                many_value.append(converted)
        return dict(fn=self.fn, args={"many_assignments": many_value}, out=self.out)


@dataclass
class SubManyCall(FunctionCall):
    """Serializer for a ``sub_many`` call; reads ``args['roots']``."""

    def to_json(self) -> dict:
        """Return the ``fn`` / ``args`` / ``out`` payload for this call."""
        roots = self.args['roots']
        # One FunctionCall replaces the whole list; otherwise convert per item.
        roots_value = (
            self._ref_or_value(roots)
            if isinstance(roots, FunctionCall)
            else self._ref_or_value_list(roots)
        )
        return dict(fn=self.fn, args={"roots": roots_value}, out=self.out)


@dataclass
class ValidateCall(FunctionCall):
    """Serializer for a ``validate`` call; reads ``args['dag']``."""

    def to_json(self) -> dict:
        """Return the ``fn`` / ``args`` / ``out`` payload for this call."""
        call_args = {"dag": self._ref_or_value(self.args['dag'])}
        return dict(fn=self.fn, args=call_args, out=self.out)


@dataclass
class RanksCall(FunctionCall):
    """Serializer for a ``ranks`` call; reads ``args['dag']``."""

    def to_json(self) -> dict:
        """Return the ``fn`` / ``args`` / ``out`` payload for this call."""
        call_args = {"dag": self._ref_or_value(self.args['dag'])}
        return dict(fn=self.fn, args=call_args, out=self.out)


@dataclass
class DeleteNodeCall(FunctionCall):
    """Serializer for a ``delete_node`` call; reads ``args['id']``."""

    def to_json(self) -> dict:
        """Return the ``fn`` / ``args`` / ``out`` payload for this call."""
        call_args = {"id": self._ref_or_value(self.args['id'])}
        return dict(fn=self.fn, args=call_args, out=self.out)


@dataclass
class DeleteSubCall(FunctionCall):
    """Serializer for a ``delete_sub`` call; reads ``args['roots']``."""

    def to_json(self) -> dict:
        """Return the ``fn`` / ``args`` / ``out`` payload for this call."""
        roots = self.args['roots']
        # One FunctionCall replaces the whole list; otherwise convert per item.
        roots_value = (
            self._ref_or_value(roots)
            if isinstance(roots, FunctionCall)
            else self._ref_or_value_list(roots)
        )
        return dict(fn=self.fn, args={"roots": roots_value}, out=self.out)


@dataclass
class SolveCall(FunctionCall):
    """Serializer for a ``solve`` call.

    Reads ``args['dag']`` plus the optional entries ``objective`` (default
    ``[]``), ``assume`` (default ``[]``) and ``maximize`` (default ``True``).
    ``objective`` and ``assume`` may each be a FunctionCall, which is then
    sent as a reference instead of being expanded.
    """

    def to_json(self) -> dict:
        """Return the ``fn`` / ``args`` / ``out`` payload for this call."""
        dag_value = self._ref_or_value(self.args['dag'])

        # Objective: list of {id, coefficient} records, or one reference.
        objective = self.args.get('objective', [])
        if isinstance(objective, FunctionCall):
            objective_value = self._ref_or_value(objective)
        else:
            objective_value = []
            for term in objective:
                objective_value.append({
                    "id": self._ref_or_value(term['id']),
                    "coefficient": term['coefficient'],
                })

        # Assumptions: list of {id, bound} records, or one reference.
        assume = self.args.get('assume', [])
        if isinstance(assume, FunctionCall):
            assume_value = self._ref_or_value(assume)
        else:
            assume_value = []
            for fixed in assume:
                assume_value.append({
                    "id": self._ref_or_value(fixed['id']),
                    "bound": fixed['bound'],
                })

        call_args = {
            "dag": dag_value,
            "objective": objective_value,
            "assume": assume_value,
            "maximize": self.args.get('maximize', True),
        }
        return dict(fn=self.fn, args=call_args, out=self.out)


@dataclass
class SolveManyCall(FunctionCall):
    """Serializer for a ``solve_many`` call.

    Reads ``args['dag']`` plus the optional entries ``objectives`` (a list of
    objectives, default ``[]``), ``assume`` (default ``[]``) and ``maximize``
    (default ``True``). ``objectives`` and ``assume`` may each be a
    FunctionCall, which is then sent as a reference instead of being expanded.
    """

    def to_json(self) -> dict:
        """Return the ``fn`` / ``args`` / ``out`` payload for this call."""
        dag_value = self._ref_or_value(self.args['dag'])

        # Objectives: a list of lists of {id, coefficient} records, or one
        # reference replacing the whole structure.
        objectives = self.args.get('objectives', [])
        if isinstance(objectives, FunctionCall):
            objectives_value = self._ref_or_value(objectives)
        else:
            objectives_value = []
            for single in objectives:
                terms = []
                for term in single:
                    terms.append({
                        "id": self._ref_or_value(term['id']),
                        "coefficient": term['coefficient'],
                    })
                objectives_value.append(terms)

        # Assumptions: list of {id, bound} records, or one reference.
        assume = self.args.get('assume', [])
        if isinstance(assume, FunctionCall):
            assume_value = self._ref_or_value(assume)
        else:
            assume_value = []
            for fixed in assume:
                assume_value.append({
                    "id": self._ref_or_value(fixed['id']),
                    "bound": fixed['bound'],
                })

        call_args = {
            "dag": dag_value,
            "objectives": objectives_value,
            "assume": assume_value,
            "maximize": self.args.get('maximize', True),
        }
        return dict(fn=self.fn, args=call_args, out=self.out)


class OatDBError(Exception):
    """Root of the OatDB client exception hierarchy.

    Catch this type to handle every error raised by the client.
    """


class OatDBConnectionError(OatDBError):
    """Signals that the request to the OatDB server could not be completed."""


class OatDBExecutionError(OatDBError):
    """Signals that OatDB execution failed.

    The HTTP status and the response body supplied by the raiser are kept
    on the instance as ``status_code`` and ``response``.
    """

    def __init__(self, message: str, status_code: int, response: dict):
        # Only ``message`` is forwarded to Exception, so str(error) is the
        # message alone; the extra details live in attributes.
        self.status_code = status_code
        self.response = response
        super().__init__(message)


@dataclass
class OatClient:
    """
    Client for interacting with OatDB API.

    The builder methods (``set_*``, ``get_*``, ``sub``, ``solve`` ...) do not
    contact the server. Each one appends a FunctionCall to an internal buffer
    and returns it, so the returned object can be passed as an argument to
    later calls. ``execute`` sends the whole buffer in a single request.

    Usage:
        client = OatClient("http://localhost:7061")

        # Build query
        x = client.set_primitive("x", bound=10j)
        y = client.set_primitive("y", bound=10j)
        constraint = client.set_and([x, y])
        dag = client.sub(constraint)

        # Execute
        result = client.execute([dag.out])
        print(result[dag.out])
    """
    # Root URL of the server; a trailing slash is tolerated.
    base_url: str
    # Calls queued by the builder methods and sent by execute().
    # NOTE(review): being the second dataclass field, this is also the second
    # positional constructor argument - pass timeout / verify_ssl by keyword.
    _buffer: List[FunctionCall] = field(default_factory=list)
    # Timeout handed to every ``requests`` call.
    timeout: int = 30
    # Handed to ``requests`` as ``verify``.
    verify_ssl: bool = True

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _endpoint(self, path: str) -> str:
        """Return ``base_url`` joined with ``path`` without a double slash."""
        return f"{self.base_url.rstrip('/')}/{path}"

    def _enqueue(self, call):
        """Append ``call`` to the buffer and return it for chaining."""
        self._buffer.append(call)
        return call

    @staticmethod
    def _with_alias(args: dict, alias: Optional[str]) -> dict:
        """Add ``alias`` to ``args`` when it is truthy and return ``args``."""
        if alias:
            args["alias"] = alias
        return args

    @staticmethod
    def _error_body(response) -> dict:
        """Best-effort decode of the JSON body of a failed response.

        Returns ``{}`` when the content type is not JSON or the body cannot
        be decoded, so reporting the failure never raises by itself.
        """
        # The header may carry parameters ("application/json; charset=utf-8"),
        # so an exact comparison would miss JSON bodies.
        content_type = response.headers.get('content-type') or ''
        if 'application/json' not in content_type.lower():
            return {}
        try:
            return response.json()
        except ValueError:
            return {}

    # ------------------------------------------------------------------
    # Buffer / server utilities
    # ------------------------------------------------------------------

    def health_check(self) -> bool:
        """Check if the OatDB server is healthy.

        Returns:
            True when ``GET <base_url>/health`` answers with status 200,
            False for any other status or any ``requests`` error.
        """
        try:
            response = requests.get(
                self._endpoint("health"),
                timeout=self.timeout,
                verify=self.verify_ssl
            )
        except requests.RequestException:
            return False
        return response.status_code == 200

    def clear_buffer(self):
        """Clear all buffered calls."""
        self._buffer.clear()

    def get_buffer_size(self) -> int:
        """Get the number of buffered calls."""
        return len(self._buffer)

    # ------------------------------------------------------------------
    # Builders
    # ------------------------------------------------------------------

    def set_primitive(
        self,
        id: Union[str, FunctionCall],
        bound: complex = 1j
    ) -> SetPrimitiveCall:
        """
        Create a primitive variable.

        Args:
            id: Variable identifier or reference to another call
            bound: Variable bound as complex number (default: 0+1j means [0, 1])

        Returns:
            SetPrimitiveCall that can be referenced by other calls
        """
        return self._enqueue(SetPrimitiveCall(
            fn="set_primitive",
            args={"id": id, "bound": bound}
        ))

    def set_gelineq(
        self,
        coefficients: Union[Dict[Union[str, FunctionCall], int], FunctionCall],
        bias: Union[int, FunctionCall],
        alias: Optional[str] = None
    ) -> SetGelineqCall:
        """
        Create a linear inequality constraint.

        Args:
            coefficients: Dict mapping variable IDs to coefficients, or reference
            bias: Bias term
            alias: Optional alias; stored in the call's args even when None

        Returns:
            SetGelineqCall
        """
        # "alias" is always stored (even as None) so the call's ``out`` hash
        # keeps the value it had before this helper-based rewrite.
        return self._enqueue(SetGelineqCall(
            fn="set_gelineq",
            args={"coefficients": coefficients, "bias": bias, "alias": alias}
        ))

    def set_atleast(
        self,
        references: Union[List[Union[str, FunctionCall]], FunctionCall],
        value: Union[int, FunctionCall],
        alias: Optional[str] = None
    ) -> SetThresholdOperatorCall:
        """Create an 'at least' constraint."""
        args = self._with_alias({"references": references, "value": value}, alias)
        return self._enqueue(SetThresholdOperatorCall(fn="set_atleast", args=args))

    def set_atmost(
        self,
        references: Union[List[Union[str, FunctionCall]], FunctionCall],
        value: Union[int, FunctionCall],
        alias: Optional[str] = None
    ) -> SetThresholdOperatorCall:
        """Create an 'at most' constraint."""
        args = self._with_alias({"references": references, "value": value}, alias)
        return self._enqueue(SetThresholdOperatorCall(fn="set_atmost", args=args))

    def set_equal(
        self,
        references: Union[List[Union[str, FunctionCall]], FunctionCall],
        value: Union[int, FunctionCall],
        alias: Optional[str] = None
    ) -> SetThresholdOperatorCall:
        """Create an 'equal' constraint."""
        args = self._with_alias({"references": references, "value": value}, alias)
        return self._enqueue(SetThresholdOperatorCall(fn="set_equal", args=args))

    def set_and(
        self,
        references: Union[List[Union[str, FunctionCall]], FunctionCall],
        alias: Optional[str] = None
    ) -> SetLogicalOperatorCall:
        """Create an AND constraint."""
        args = self._with_alias({"references": references}, alias)
        return self._enqueue(SetLogicalOperatorCall(fn="set_and", args=args))

    def set_or(
        self,
        references: Union[List[Union[str, FunctionCall]], FunctionCall],
        alias: Optional[str] = None
    ) -> SetLogicalOperatorCall:
        """Create an OR constraint."""
        args = self._with_alias({"references": references}, alias)
        return self._enqueue(SetLogicalOperatorCall(fn="set_or", args=args))

    def set_not(
        self,
        references: Union[List[Union[str, FunctionCall]], FunctionCall],
        alias: Optional[str] = None
    ) -> SetLogicalOperatorCall:
        """Create a NOT constraint."""
        args = self._with_alias({"references": references}, alias)
        return self._enqueue(SetLogicalOperatorCall(fn="set_not", args=args))

    def set_xor(
        self,
        references: Union[List[Union[str, FunctionCall]], FunctionCall],
        alias: Optional[str] = None
    ) -> SetLogicalOperatorCall:
        """Create an XOR constraint."""
        args = self._with_alias({"references": references}, alias)
        return self._enqueue(SetLogicalOperatorCall(fn="set_xor", args=args))

    def set_imply(
        self,
        lhs: Union[str, FunctionCall],
        rhs: Union[str, FunctionCall],
        alias: Optional[str] = None
    ) -> SetBinaryOperatorCall:
        """Create an IMPLY constraint (lhs => rhs)."""
        args = self._with_alias({"lhs": lhs, "rhs": rhs}, alias)
        return self._enqueue(SetBinaryOperatorCall(fn="set_imply", args=args))

    def set_equiv(
        self,
        lhs: Union[str, FunctionCall],
        rhs: Union[str, FunctionCall],
        alias: Optional[str] = None
    ) -> SetBinaryOperatorCall:
        """Create an EQUIV constraint (lhs <=> rhs)."""
        args = self._with_alias({"lhs": lhs, "rhs": rhs}, alias)
        return self._enqueue(SetBinaryOperatorCall(fn="set_equiv", args=args))

    def get_id_from_alias(
        self,
        alias: Union[str, FunctionCall]
    ) -> GetIdFromAliasCall:
        """Get node ID from alias."""
        return self._enqueue(GetIdFromAliasCall(
            fn="get_id_from_alias",
            args={"alias": alias}
        ))

    def sub(
        self,
        root: Union[str, FunctionCall]
    ) -> SubCall:
        """Extract a sub-DAG starting from root."""
        return self._enqueue(SubCall(
            fn="sub",
            args={"root": root}
        ))

    def solve(
        self,
        dag: Union[dict, FunctionCall],
        objective: Optional[List[Dict[str, Any]]] = None,
        assume: Optional[List[Dict[str, Any]]] = None,
        maximize: bool = True
    ) -> SolveCall:
        """
        Solve an optimization problem.

        Args:
            dag: DAG to solve (typically from a sub() call)
            objective: List of {id: ..., coefficient: ...} dicts
            assume: List of {id: ..., bound: [min, max]} dicts
            maximize: Whether to maximize (True) or minimize (False)

        Returns:
            SolveCall
        """
        return self._enqueue(SolveCall(
            fn="solve",
            args={
                "dag": dag,
                # None (and any other falsy value) is normalised to [].
                "objective": objective or [],
                "assume": assume or [],
                "maximize": maximize
            }
        ))

    def set_property(
        self,
        id: Union[str, FunctionCall],
        property: str,
        value: Any
    ) -> SetPropertyCall:
        """Set a property value on a node."""
        return self._enqueue(SetPropertyCall(
            fn="set_property",
            args={"id": id, "property": property, "value": value}
        ))

    def set_primitives(
        self,
        ids: Union[List[Union[str, FunctionCall]], FunctionCall],
        bound: complex = 1j
    ) -> SetPrimitivesCall:
        """Create multiple primitive variables with the same bound."""
        return self._enqueue(SetPrimitivesCall(
            fn="set_primitives",
            args={"ids": ids, "bound": bound}
        ))

    def get_node_ids(
        self,
        dag: Union[dict, FunctionCall]
    ) -> GetNodeIdsCall:
        """Get all node IDs from a DAG.

        NOTE(review): GetNodeIdsCall.to_json only serializes a 'filter'
        argument, so ``dag`` currently affects the call's ``out`` hash but is
        not sent - confirm the intended server API.
        """
        return self._enqueue(GetNodeIdsCall(
            fn="get_node_ids",
            args={"dag": dag}
        ))

    def get_nodes(
        self,
        ids: Union[List[Union[str, FunctionCall]], FunctionCall]
    ) -> GetNodesCall:
        """Get multiple nodes by their IDs."""
        return self._enqueue(GetNodesCall(
            fn="get_nodes",
            args={"ids": ids}
        ))

    def get_node(
        self,
        id: Union[str, FunctionCall]
    ) -> GetNodeCall:
        """Get a single node by its ID."""
        return self._enqueue(GetNodeCall(
            fn="get_node",
            args={"id": id}
        ))

    def get_property_values(
        self,
        property: str
    ) -> GetPropertyValuesCall:
        """Get all nodes that have a specific property and their values."""
        return self._enqueue(GetPropertyValuesCall(
            fn="get_property_values",
            args={"property": property}
        ))

    def get_alias(
        self,
        id: Union[str, FunctionCall]
    ) -> GetAliasCall:
        """Get alias for a given ID."""
        return self._enqueue(GetAliasCall(
            fn="get_alias",
            args={"id": id}
        ))

    def get_aliases_from_id(
        self,
        id: Union[str, FunctionCall]
    ) -> GetAliasesFromIdCall:
        """Get all aliases for a given ID."""
        return self._enqueue(GetAliasesFromIdCall(
            fn="get_aliases_from_id",
            args={"id": id}
        ))

    def get_ids_from_aliases(
        self,
        aliases: Union[List[str], FunctionCall]
    ) -> GetIdsFromAliasesCall:
        """Get IDs for multiple aliases."""
        return self._enqueue(GetIdsFromAliasesCall(
            fn="get_ids_from_aliases",
            args={"aliases": aliases}
        ))

    def propagate(
        self,
        assignments: Union[List[Dict[str, Any]], FunctionCall]
    ) -> PropagateCall:
        """Propagate constraints given initial assignments."""
        return self._enqueue(PropagateCall(
            fn="propagate",
            args={"assignments": assignments}
        ))

    def propagate_many(
        self,
        many_assignments: Union[List[List[Dict[str, Any]]], FunctionCall]
    ) -> PropagateManyCall:
        """Propagate constraints given multiple sets of initial assignments."""
        return self._enqueue(PropagateManyCall(
            fn="propagate_many",
            args={"many_assignments": many_assignments}
        ))

    def sub_many(
        self,
        roots: Union[List[Union[str, FunctionCall]], FunctionCall]
    ) -> SubManyCall:
        """Extract sub-DAGs from multiple root nodes."""
        return self._enqueue(SubManyCall(
            fn="sub_many",
            args={"roots": roots}
        ))

    def solve_many(
        self,
        dag: Union[dict, FunctionCall],
        objectives: Optional[List[List[Dict[str, Any]]]] = None,
        assume: Optional[List[Dict[str, Any]]] = None,
        maximize: bool = True
    ) -> SolveManyCall:
        """
        Solve multiple optimization problems on the same DAG.

        Args:
            dag: DAG to solve (typically from a sub() call)
            objectives: List of objectives, where each objective is a list of {id: ..., coefficient: ...} dicts
            assume: List of {id: ..., bound: [min, max]} dicts
            maximize: Whether to maximize (True) or minimize (False)

        Returns:
            SolveManyCall
        """
        return self._enqueue(SolveManyCall(
            fn="solve_many",
            args={
                "dag": dag,
                # None (and any other falsy value) is normalised to [].
                "objectives": objectives or [],
                "assume": assume or [],
                "maximize": maximize
            }
        ))

    def validate(
        self,
        dag: Union[dict, FunctionCall]
    ) -> ValidateCall:
        """Validate a DAG structure."""
        return self._enqueue(ValidateCall(
            fn="validate",
            args={"dag": dag}
        ))

    def ranks(
        self,
        dag: Union[dict, FunctionCall]
    ) -> RanksCall:
        """Compute topological ranks for nodes in a DAG."""
        return self._enqueue(RanksCall(
            fn="ranks",
            args={"dag": dag}
        ))

    def delete_node(
        self,
        id: Union[str, FunctionCall]
    ) -> DeleteNodeCall:
        """Delete a node by its ID."""
        return self._enqueue(DeleteNodeCall(
            fn="delete_node",
            args={"id": id}
        ))

    def delete_sub(
        self,
        roots: Union[List[Union[str, FunctionCall]], FunctionCall]
    ) -> DeleteSubCall:
        """Delete sub-DAGs from the given roots."""
        return self._enqueue(DeleteSubCall(
            fn="delete_sub",
            args={"roots": roots}
        ))

    # ------------------------------------------------------------------
    # Execution
    # ------------------------------------------------------------------

    def execute(
        self,
        outputs: Optional[List[str]] = None,
        clear_buffer: bool = True
    ) -> Dict[str, Any]:
        """
        Execute all buffered calls.

        Args:
            outputs: List of output variable names to return.
                    If None, returns all outputs.
            clear_buffer: Whether to clear the buffer after execution

        Returns:
            ``{}`` without contacting the server when the buffer is empty.
            With ``outputs``: ``dict(zip(outputs, results))`` where ``results``
            is the decoded response body. Without ``outputs``: the decoded
            response body as returned by the server.

        Raises:
            OatDBConnectionError: If connection to server fails
            OatDBExecutionError: If server returns an error
        """
        if not self._buffer:
            return {}

        payload: Dict[str, Any] = self.debug_payload()
        if outputs:
            payload["outputs"] = outputs

        try:
            response = requests.post(
                self._endpoint("call"),
                json=payload,
                timeout=self.timeout,
                verify=self.verify_ssl
            )

            if response.status_code != 200:
                # Not a RequestException, so it passes through the handler
                # below unchanged. The buffer is left intact on failure.
                raise OatDBExecutionError(
                    f"OatDB execution failed: {response.status_code}",
                    response.status_code,
                    self._error_body(response)
                )

            results = response.json()
        except requests.RequestException as e:
            raise OatDBConnectionError(f"Failed to connect to OatDB: {e}") from e

        if clear_buffer:
            self.clear_buffer()

        # If outputs were specified, pair each requested name with the value
        # at the same position in the response.
        if outputs:
            return dict(zip(outputs, results))

        # Otherwise return the raw results
        return results

    def debug_payload(self) -> dict:
        """
        Get the JSON payload that would be sent without executing.
        Useful for debugging.
        """
        return {
            "calls": [call.to_json() for call in self._buffer]
        }

# Export public API.
# Only 'OatClient' is listed, so a star-import of this module exposes nothing
# else; the exception classes, the call classes and from_sub_data stay
# importable by explicit name.
__all__ = [
    'OatClient'
]