import asyncio
import aiohttp
from urllib.parse import urljoin, urlparse
from bs4 import BeautifulSoup  # Make sure to install beautifulsoup4 via pip
import markdownify
from ai_kit.shared_console import shared_console

def normalize_domain(domain: str) -> str:
    """Return *domain* lower-cased, without a leading ``www.`` or trailing dots.

    Only a ``www.`` that starts the domain is removed; the same characters
    appearing later in the host (e.g. ``awww.example.com`` or
    ``docs.www.example.com``) are left untouched.
    """
    prefix = 'www.'
    domain = domain.lower()
    # Strip the prefix only when it really is a prefix; a blanket
    # str.replace would also mangle hosts that merely contain "www.".
    if domain.startswith(prefix):
        domain = domain[len(prefix):]
    return domain.rstrip('.')

class Crawler:
    """Asynchronous crawler that collects the pages located under a seed URL's path.

    Worker tasks pull URLs from a shared queue, fetch them through an aiohttp
    session, and record each successfully fetched page in ``self.sitemap``
    together with its HTML, a markdown rendering and the child links found on it.
    """

    def __init__(self, seed_url, base_domain, max_workers=10, max_urls=100, throttle_limit=5, max_retries=2):
        """
        Initialize the crawler with the seed URL, base domain, number of worker tasks,
        the maximum number of URLs to crawl, the concurrency throttling limit,
        and the maximum number of attempts for an HTTP request.

        - seed_url: Starting URL for the crawl. Its host is normalized with
          normalize_domain() and its path becomes the prefix that followed links must share.
        - base_domain: Domain to restrict to; a scheme, if present, is stripped.
          A link is followed only when its normalized host equals this value.
        - max_workers: Number of concurrent worker tasks.
        - max_urls: Maximum number of pages/URLs to crawl.
        - throttle_limit: Maximum number of concurrent HTTP fetches.
        - max_retries: Total number of attempts made for each request.
        """
        # Strip scheme if present so that base_domain is just the domain
        if "://" in base_domain:
            base_domain = urlparse(base_domain).netloc
        self.base_domain = normalize_domain(base_domain)

        # Parse and normalize the seed URL's domain to match base domain normalization
        seed_parsed = urlparse(seed_url)
        normalized_domain = normalize_domain(seed_parsed.netloc)
        # Rebuild the seed URL with the normalized domain
        self.seed_url = seed_parsed._replace(netloc=normalized_domain).geturl()

        self.max_workers = max_workers
        self.max_urls = max_urls
        self.max_retries = max_retries
        self.visited = set()
        self.queue = asyncio.Queue()
        # {url: {"url": url, "html": html, "children": [child_url, ...], "markdown": markdown}}
        self.sitemap = {}

        # Semaphore to throttle concurrent HTTP requests.
        self.semaphore = asyncio.Semaphore(throttle_limit)

        # Path of the seed URL; an empty path is treated as the site root.
        self.seed_path = seed_parsed.path or "/"

    async def fetch(self, session, url):
        """
        Fetch the body of ``url`` as text, throttled by the semaphore.

        Up to ``self.max_retries`` attempts are made, with a one second pause
        between consecutive attempts. Non-200 responses and exceptions are
        printed and count as failed attempts.

        Returns the response text of the first 200 response, or None when every
        attempt failed.
        """
        last_attempt = self.max_retries - 1
        for attempt in range(self.max_retries):
            try:
                async with self.semaphore:
                    async with session.get(url, timeout=10) as response:
                        if response.status == 200:
                            return await response.text()
                        print(f"Non-200 status code {response.status} for URL: {url}")
            except asyncio.CancelledError:
                # Never swallow cancellation (it derives from Exception on Python 3.7).
                raise
            except Exception as e:
                # Best-effort crawl: report the failure and fall through to the retry.
                print(f"Error fetching {url}: {e} (attempt {attempt + 1}/{self.max_retries})")
            # Pause only when another attempt follows; sleeping after the final
            # attempt would just delay the worker for no benefit.
            if attempt < last_attempt:
                await asyncio.sleep(1)  # brief delay before retrying
        return None

    def extract_links(self, base_url, html):
        """
        Extract the href targets of all <a> tags in ``html`` that stay inside the crawl scope.

        Relative URLs are resolved against ``base_url``. A link is kept only when it
        uses http or https, its normalized host equals ``self.base_domain``, and its
        path is either exactly the seed path or starts with the seed path followed
        by a '/'.

        Returns a set of URLs rebuilt from scheme, host and path only, so query
        strings and fragments are dropped. The host is emitted as written in the
        link (it is not rewritten to the normalized form).
        """
        # The allowed path prefix does not depend on the link, so compute it once.
        allowed = self.seed_path
        allowed_prefix = allowed if allowed.endswith('/') else allowed + '/'

        links = set()
        soup = BeautifulSoup(html, "html.parser")
        for tag in soup.find_all("a"):
            href = tag.get("href")
            if not href:
                continue
            # Clean up the href (remove extraneous quotes or spaces)
            href = href.strip(' "\'')
            parsed = urlparse(urljoin(base_url, href))
            if parsed.scheme not in ('http', 'https') or not parsed.netloc:
                continue
            if normalize_domain(parsed.netloc) != self.base_domain:
                continue
            # Only follow links whose path equals the seed path or is a descendant.
            if parsed.path == allowed or parsed.path.startswith(allowed_prefix):
                links.add(f"{parsed.scheme}://{parsed.netloc}{parsed.path}")
        return links

    async def worker(self, name, session):
        """
        Worker task that continuously processes URLs from the queue.

        Each dequeued URL is skipped when the ``max_urls`` limit has been reached or
        the URL was already visited; otherwise it is fetched, stored in the sitemap,
        and its in-scope links are enqueued. The loop never returns on its own; it
        ends when the task is cancelled.
        """
        while True:
            url = await self.queue.get()
            try:
                # Skip when the crawl budget is exhausted or the URL was already handled.
                if len(self.visited) >= self.max_urls or url in self.visited:
                    continue

                print(f"{name} crawling: {url}")
                self.visited.add(url)
                html = await self.fetch(session, url)
                if not html:
                    continue

                # Save the page info in the sitemap.
                page = {
                    "url": url,
                    "html": html,
                    "children": [],
                    "markdown": self.parse(html),
                }
                self.sitemap[url] = page

                links = self.extract_links(url, html)
                # extract_links returns a set, so the children are already unique.
                page["children"].extend(links)
                for link in links:
                    # Enqueue the child page if not visited and within the max_urls limit.
                    if link not in self.visited and len(self.visited) < self.max_urls:
                        await self.queue.put(link)
            except asyncio.CancelledError:
                raise
            except Exception as e:
                # A failure on one page must not kill the worker: a dead worker
                # stops draining the queue.
                print(f"{name} failed processing {url}: {e}")
            finally:
                # Always balance queue.get(); otherwise queue.join() never returns.
                self.queue.task_done()

    async def crawl(self):
        """
        Start the asynchronous crawling process.

        It enqueues the seed URL, starts the worker tasks, waits until the queue is
        fully processed, cancels the worker tasks and returns the sitemap object
        containing each page's URL, HTML content, markdown and child links.
        """
        await self.queue.put(self.seed_url)
        async with aiohttp.ClientSession() as session:
            tasks = [
                asyncio.create_task(self.worker(f"Worker-{i}", session))
                for i in range(self.max_workers)
            ]
            await self.queue.join()
            for task in tasks:
                task.cancel()
            # Wait for the cancellations to finish while the session is still open;
            # return_exceptions=True absorbs the resulting CancelledError results.
            await asyncio.gather(*tasks, return_exceptions=True)
        return self.sitemap

    def parse(self, html: str) -> str:
        """
        Convert ``html`` to markdown via markdownify.

        The footer, header, a and svg tags are passed in markdownify's ``strip``
        option and '*' is used for bullets. Every "\\n\\n" sequence is then removed
        and the result is stripped. Returns "" for empty input.
        """
        if not html:
            return ""
        markdown = markdownify.markdownify(
            html,
            strip=["footer", "header", "a", "svg"],
            bullets="*",
        )
        return markdown.replace("\n\n", "").strip()

    def pretty_print_sitemap(self, current_url=None, indent=0, visited=None):
        """
        Recursively prints the sitemap starting from the current_url.
        If current_url is None, it starts from the seed URL.
        This function will show each URL along with a snippet of its HTML content,
        indented by ``indent`` spaces; children are indented two spaces further.

        ``visited`` tracks URLs already printed so that link cycles terminate.
        """
        if visited is None:
            visited = set()
        if current_url is None:
            current_url = self.seed_url
        if current_url in visited:
            return
        visited.add(current_url)
        node = self.sitemap.get(current_url)
        if not node:
            return
        indent_str = " " * indent
        # Print the URL and a snippet of HTML (first 100 characters).
        shared_console.print(f"{indent_str}[bold blue]URL: {node['url']}[/bold blue]")
        html_snippet = node['html'][:100].replace("\n", " ") if node['html'] else ""
        shared_console.print(f"{indent_str}[yellow]HTML: {html_snippet}...[/yellow]")
        shared_console.print("\n")
        # Recursively print the children.
        for child in node["children"]:
            self.pretty_print_sitemap(child, indent=indent + 2, visited=visited)
