# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict

from agentkit.client import BaseAgentkitClient
from agentkit.utils import get_logger

# Import the auto-generated request/response types
from agentkit.memory.memory_all_types import (
    CreateMemoryCollectionRequest,
    CreateMemoryCollectionResponse,
    UpdateMemoryCollectionRequest,
    UpdateMemoryCollectionResponse,
    DeleteMemoryCollectionRequest,
    DeleteMemoryCollectionResponse,
    ListMemoryCollectionsRequest,
    ListMemoryCollectionsResponse,
    AddMemoryCollectionRequest,
    AddMemoryCollectionResponse,
    GetMemoryCollectionRequest,
    GetMemoryCollectionResponse,
    GetMemoryConnectionInfoRequest,
    GetMemoryConnectionInfoResponse,
)

logger = get_logger(__name__)


class AgentkitMemory(BaseAgentkitClient):
    """Client for the AgentKit memory-collection management service.

    Each public method is a thin, typed wrapper. It forwards its request
    object to ``self._invoke_api`` under the matching API action name and
    passes along the paired response type as ``response_type``.
    """

    # Registry of every API action this service exposes. Each action's
    # value is identical to its key; insertion order follows the tuple.
    API_ACTIONS: Dict[str, str] = {
        action_name: action_name
        for action_name in (
            "CreateMemoryCollection",
            "UpdateMemoryCollection",
            "DeleteMemoryCollection",
            "ListMemoryCollections",
            "AddMemoryCollection",
            "GetMemoryCollection",
            "GetMemoryConnectionInfo",
        )
    }

    def __init__(
        self,
        access_key: str = "",
        secret_key: str = "",
        region: str = "",
        session_token: str = "",
    ) -> None:
        """Initialise the client bound to the ``memory`` service.

        Args:
            access_key: Access key forwarded to the base client.
            secret_key: Secret key forwarded to the base client.
            region: Region forwarded to the base client.
            session_token: Session token forwarded to the base client.
        """
        super().__init__(
            access_key=access_key,
            secret_key=secret_key,
            region=region,
            session_token=session_token,
            service_name="memory",
        )

    def create_memory_collection(
        self, request: CreateMemoryCollectionRequest
    ) -> CreateMemoryCollectionResponse:
        """Create a new memory collection.

        Args:
            request: The ``CreateMemoryCollection`` request payload.

        Returns:
            Whatever ``_invoke_api`` returns for ``CreateMemoryCollectionResponse``.
        """
        action = "CreateMemoryCollection"
        return self._invoke_api(
            api_action=action,
            request=request,
            response_type=CreateMemoryCollectionResponse,
        )

    def update_memory_collection(
        self, request: UpdateMemoryCollectionRequest
    ) -> UpdateMemoryCollectionResponse:
        """Update an existing memory collection.

        Args:
            request: The ``UpdateMemoryCollection`` request payload.

        Returns:
            Whatever ``_invoke_api`` returns for ``UpdateMemoryCollectionResponse``.
        """
        action = "UpdateMemoryCollection"
        return self._invoke_api(
            api_action=action,
            request=request,
            response_type=UpdateMemoryCollectionResponse,
        )

    def delete_memory_collection(
        self, request: DeleteMemoryCollectionRequest
    ) -> DeleteMemoryCollectionResponse:
        """Delete a memory collection.

        Args:
            request: The ``DeleteMemoryCollection`` request payload.

        Returns:
            Whatever ``_invoke_api`` returns for ``DeleteMemoryCollectionResponse``.
        """
        action = "DeleteMemoryCollection"
        return self._invoke_api(
            api_action=action,
            request=request,
            response_type=DeleteMemoryCollectionResponse,
        )

    def list_memory_collections(
        self, request: ListMemoryCollectionsRequest
    ) -> ListMemoryCollectionsResponse:
        """List memory collections.

        Args:
            request: The ``ListMemoryCollections`` request payload. It
                carries paging fields such as ``page_number`` and
                ``page_size``.

        Returns:
            Whatever ``_invoke_api`` returns for ``ListMemoryCollectionsResponse``.
        """
        action = "ListMemoryCollections"
        return self._invoke_api(
            api_action=action,
            request=request,
            response_type=ListMemoryCollectionsResponse,
        )

    def add_memory_collection(
        self, request: AddMemoryCollectionRequest
    ) -> AddMemoryCollectionResponse:
        """Add memory collections from external providers.

        Args:
            request: The ``AddMemoryCollection`` request payload.

        Returns:
            Whatever ``_invoke_api`` returns for ``AddMemoryCollectionResponse``.
        """
        action = "AddMemoryCollection"
        return self._invoke_api(
            api_action=action,
            request=request,
            response_type=AddMemoryCollectionResponse,
        )

    def get_memory_collection(
        self, request: GetMemoryCollectionRequest
    ) -> GetMemoryCollectionResponse:
        """Fetch detailed information about one memory collection.

        Args:
            request: The ``GetMemoryCollection`` request payload.

        Returns:
            Whatever ``_invoke_api`` returns for ``GetMemoryCollectionResponse``.
        """
        action = "GetMemoryCollection"
        return self._invoke_api(
            api_action=action,
            request=request,
            response_type=GetMemoryCollectionResponse,
        )

    def get_memory_connection_info(
        self, request: GetMemoryConnectionInfoRequest
    ) -> GetMemoryConnectionInfoResponse:
        """Fetch connection information for one memory collection.

        Args:
            request: The ``GetMemoryConnectionInfo`` request payload.

        Returns:
            Whatever ``_invoke_api`` returns for ``GetMemoryConnectionInfoResponse``.
        """
        action = "GetMemoryConnectionInfo"
        return self._invoke_api(
            api_action=action,
            request=request,
            response_type=GetMemoryConnectionInfoResponse,
        )


if __name__ == "__main__":
    # Manual smoke run: build a client with default (empty) credentials,
    # request the first page of up to 10 memory collections, and print it.
    client = AgentkitMemory()
    first_page = client.list_memory_collections(
        ListMemoryCollectionsRequest(page_number=1, page_size=10)
    )
    print(first_page)