from collections import Counter
import threading
import itertools as it
from collections import defaultdict

from mcpp.parse import get_identifiers
from mcpp.queries import Q_ARGLIST, Q_IDENTIFIER, Q_FUNCTION, Q_PARAMETER, \
    Q_POINTER_EXPR, Q_ASSIGNMENT_EXPR, Q_BINARY_EXPR, Q_UPDATE_EXPR, Q_SUBSCRIPT_EXPR, \
    Q_FIELD_EXPR, Q_CALL_NAME, Q_IF_STMT, Q_SWITCH_STMT, Q_DO_STMT, Q_WHILE_STMT, \
    Q_FOR_STMT, Q_FOR_RANGE_STMT, Q_CONDITION, Q_IF_WITHOUT_ELSE, Q_POINTER_IDENTIFIER


def v1(root, sitter, lang, calls=None):
    """
    V1: number of parameter variables

    Only the first node captured as ``function`` by ``Q_FUNCTION`` below
    ``root`` is inspected; its ``param`` captures from ``Q_PARAMETER`` are
    counted. If no function is captured at all, the metric is 0.

    ``calls`` is part of the signature shared by the metric functions in this
    file and is not used here.

    Returns a dict with the single key ``"V1"``.
    """
    sitter.add_queries({"Q_FUNCTION": Q_FUNCTION, "Q_PARAMETER": Q_PARAMETER})

    captured_functions = sitter.captures("Q_FUNCTION", root, lang).get("function", [])
    if len(captured_functions) < 1:
        return {"V1": 0}

    # Parameters are looked up inside the first captured function only.
    first_function = captured_functions[0]
    parameter_nodes = sitter.captures("Q_PARAMETER", first_function, lang).get("param", [])
    return {"V1": len(parameter_nodes)}


def v2(root, sitter, lang, calls=None):
    """
    V2: number of variables as parameters for callee functions

    For every ``args`` capture of ``Q_ARGLIST`` below ``root`` the identifiers
    returned by ``get_identifiers`` are collected; the metric is the number of
    distinct values collected over all argument lists.

    ``calls`` is forwarded unchanged as the ``filter`` keyword of
    ``get_identifiers`` (presumably callee names to ignore -- confirm against
    ``mcpp.parse``).

    Returns a dict with the single key ``"V2"``.
    """
    sitter.add_queries({"Q_ARGLIST": Q_ARGLIST})

    distinct_arguments = set()
    for argument_list in sitter.captures("Q_ARGLIST", root, lang).get("args", []):
        distinct_arguments.update(
            get_identifiers(sitter, argument_list, lang, filter=calls)
        )

    return {"V2": len(distinct_arguments)}


def v3_v4_v5(root, sitter, lang, calls=None):
    """
    V3: number of pointer arithmetic
    V4: number of variables involved in pointer arithmetics
    V5: max pointer arithmetic a variable is involved in

    The following constructs each count as one pointer arithmetic operation:

    * an update or binary expression that mentions at least one identifier
      whose name was captured by ``Q_POINTER_IDENTIFIER``,
    * every subscript expression,
    * a compound assignment (``+=``, ``-=``, ...) whose left-hand side
      mentions such a pointer identifier,
    * a pointer expression whose operator is ``*``,
    * every expression captured by ``Q_FIELD_EXPR``.

    For each counted operation, every distinct identifier name occurring in it
    has its per-name counter increased by one; V5 is the largest counter (0 if
    nothing was counted).

    ``calls`` is part of the shared metric signature and is not used here.

    Returns a dict with the keys ``"V3"``, ``"V4"`` and ``"V5"``.
    """
    sitter.add_queries({
        "Q_BINARY_EXPR": Q_BINARY_EXPR,
        "Q_UPDATE_EXPR": Q_UPDATE_EXPR,
        "Q_SUBSCRIPT_EXPR": Q_SUBSCRIPT_EXPR,
        "Q_ASSIGNMENT_EXPR": Q_ASSIGNMENT_EXPR,
        "Q_POINTER_EXPR": Q_POINTER_EXPR,
        "Q_FIELD_EXPR": Q_FIELD_EXPR,
        "Q_IDENTIFIER": Q_IDENTIFIER,
        "Q_POINTER_IDENTIFIER": Q_POINTER_IDENTIFIER,
    })
    compound_assignment_operators = frozenset((
        "+=", "-=", "*=", "/=", "|=", "&=", "^=", "<<=", ">>=", "%=",
    ))
    dereference_operators = frozenset(("*",))

    def captured(query_name, tag):
        """Return the nodes captured under ``tag`` by ``query_name`` below root."""
        return sitter.captures(query_name, root, lang).get(tag, [])

    def identifiers_in(node):
        """Return the identifier nodes below ``node`` and the set of their names."""
        nodes = sitter.captures("Q_IDENTIFIER", node, lang).get("variable", [])
        return nodes, {identifier.text.decode() for identifier in nodes}

    def operator_of(expr):
        """Return the text of the ``operator`` field of ``expr``."""
        return expr.child_by_field_name("operator").text.decode()

    # Names of all identifiers that the pointer query recognises.
    pointer_names = {
        node.text.decode()
        for node in captured("Q_POINTER_IDENTIFIER", "identifier")
    }

    operation_count = 0               # V3
    involved_nodes = []               # basis for V4
    operations_per_name = Counter()   # basis for V5

    # Update and binary expressions with at least one pointer involved.
    update_exprs = captured("Q_UPDATE_EXPR", "expr")
    binary_exprs = captured("Q_BINARY_EXPR", "expr")
    for expr in it.chain(update_exprs, binary_exprs):
        nodes, names = identifiers_in(expr)
        if names & pointer_names:
            operation_count += 1
            involved_nodes.extend(nodes)
            operations_per_name.update(names)

    # Subscript expressions are counted unconditionally.
    for expr in captured("Q_SUBSCRIPT_EXPR", "expr"):
        nodes, names = identifiers_in(expr)
        operation_count += 1
        involved_nodes.extend(nodes)
        operations_per_name.update(names)

    # Compound assignments whose left-hand side mentions a pointer.
    for expr in captured("Q_ASSIGNMENT_EXPR", "expr"):
        if operator_of(expr) not in compound_assignment_operators:
            continue
        lhs_nodes, lhs_names = identifiers_in(expr.child_by_field_name("left"))
        if not lhs_names & pointer_names:
            continue
        rhs_nodes, rhs_names = identifiers_in(expr.child_by_field_name("right"))
        operation_count += 1
        involved_nodes.extend(lhs_nodes)
        involved_nodes.extend(rhs_nodes)
        # A name occurring on both sides is still counted once per assignment.
        operations_per_name.update(lhs_names | rhs_names)

    # Pointer dereferences with the *ptr syntax.
    for expr in captured("Q_POINTER_EXPR", "pointer"):
        if operator_of(expr) not in dereference_operators:
            continue
        nodes, names = identifiers_in(expr)
        operation_count += 1
        involved_nodes.extend(nodes)
        operations_per_name.update(names)

    # Field expressions (presumably ptr->field -- confirm against Q_FIELD_EXPR)
    # are counted unconditionally.
    # NOTE(review): unlike every other category, identifiers of field
    # expressions are not added to the V4 node list -- confirm this is intended.
    for expr in captured("Q_FIELD_EXPR", "expr"):
        _, names = identifiers_in(expr)
        operation_count += 1
        operations_per_name.update(names)

    return {
        "V3": operation_count,
        # NOTE(review): this de-duplicates the captured nodes, not their
        # names, so what counts as "the same variable" depends on the node
        # type's equality/hash semantics -- confirm this matches the intended
        # definition of V4.
        "V4": len(set(involved_nodes)),
        "V5": max(operations_per_name.values(), default=0),
    }


def v5(root, sitter, lang, calls=None):
    """
    V5: maximum number of pointer arithmetic operations a variable is involved in

    Candidates are the ``expr`` captures of ``Q_BINARY_EXPR`` and
    ``Q_ASSIGNMENT_EXPR`` below ``root``. A candidate is considered only if it
    has exactly three children; the text of the middle child is taken as the
    operator and must contain one of the arithmetic markers as a substring.
    For each such candidate the identifiers returned by ``get_identifiers``
    (called with ``filter=calls``) are tallied, and the metric is the highest
    tally, or 0 when nothing was tallied.

    ``Q_CALL_NAME`` is registered with the sitter but not queried here.

    Returns a dict with the single key ``"V5"``.
    """
    sitter.add_queries({
        "Q_BINARY_EXPR": Q_BINARY_EXPR,
        "Q_ASSIGNMENT_EXPR": Q_ASSIGNMENT_EXPR,
        "Q_CALL_NAME": Q_CALL_NAME
    })
    arithmetic_markers = (
        "+", "++", "+=",
        "-", "--", "-=",
        "*", "*=",
        "/", "/=",
    )

    binary_exprs = sitter.captures("Q_BINARY_EXPR", root, lang).get("expr", [])
    assignment_exprs = sitter.captures("Q_ASSIGNMENT_EXPR", root, lang).get("expr", [])

    tally = Counter()
    for expr in binary_exprs + assignment_exprs:
        # Only "<lhs> <operator> <rhs>" shaped nodes are inspected.
        if len(expr.children) != 3:
            continue
        operator_text = expr.children[1].text.decode()
        is_arithmetic = any(marker in operator_text for marker in arithmetic_markers)
        if not is_arithmetic:
            continue
        tally.update(get_identifiers(sitter, expr, lang, filter=calls))

    return {"V5": max(tally.values(), default=0)}


def v6_v7(root, sitter, lang, calls=None):
    """
    V6: number of nested control structures
    V7: maximum level of control nesting

    Every ``stmt`` capture of the control-statement queries (if, switch, do,
    while, for, plus the range-based for query when ``lang == "cpp"``) is
    rated with ``_control_nesting_level``. V6 is the number of captures whose
    level is greater than zero; V7 is the highest level seen (0 when there are
    no captures).

    ``calls`` is part of the shared metric signature and is not used here.

    Returns a dict with the keys ``"V6"`` and ``"V7"``.
    """
    queries = dict(
        Q_IF_STMT=Q_IF_STMT,
        Q_SWITCH_STMT=Q_SWITCH_STMT,
        Q_DO_STMT=Q_DO_STMT,
        Q_WHILE_STMT=Q_WHILE_STMT,
        Q_FOR_STMT=Q_FOR_STMT,
    )
    if lang == "cpp":
        queries["Q_FOR_RANGE_STMT"] = Q_FOR_RANGE_STMT

    sitter.add_queries(queries)

    # One nesting level per captured control statement, over all queries.
    nesting_levels = [
        _control_nesting_level(stmt, lang)
        for query_name in queries
        for stmt in sitter.captures(query_name, root, lang).get("stmt", [])
    ]

    return {
        "V6": sum(1 for level in nesting_levels if level > 0),
        "V7": max(nesting_levels, default=0),
    }


def _control_nesting_level(node, lang):
    control_types = [
        "if_statement",
        "switch_statement",
        "do_statement",
        "while_statement",
        "for_statement",
    ]
    if lang == "cpp":
        control_types.append("for_range_loop")
        
    parent = node.parent
    num_control_ancestors = 0
    while parent is not None:
        if parent.type in control_types:
            num_control_ancestors += 1
        parent = parent.parent
    return num_control_ancestors


def v8(root, sitter, lang, calls=None):
    """
    V8: maximum number of control-dependent control structures

    One thread per control-statement query (if, switch, do, while, for, plus
    the range-based for query when ``lang == "cpp"``) runs
    ``_v8_single_query``; all threads fill the same ``Counter`` guarded by a
    shared lock. After every thread has been joined, the largest counter
    value ``m`` is taken and the metric is ``m + 1``, or 0 when the counter
    is empty.

    NOTE: ``Q_CONDITION`` is not part of the query set used here, although
    ``_v8_single_query`` has a code path for it.

    Returns a dict with the single key ``"V8"``.
    """
    queries = dict(
        Q_IF_STMT=Q_IF_STMT,
        Q_SWITCH_STMT=Q_SWITCH_STMT,
        Q_DO_STMT=Q_DO_STMT,
        Q_WHILE_STMT=Q_WHILE_STMT,
        Q_FOR_STMT=Q_FOR_STMT,
    )
    if lang == "cpp":
        queries["Q_FOR_RANGE_STMT"] = Q_FOR_RANGE_STMT

    sitter.add_queries(queries)

    # Number of dependent controls per enclosing control; the key is chosen by
    # the worker function (start_byte of the enclosing control it selects).
    dependents = Counter()
    lock = threading.Lock()

    workers = []
    for query_name in queries:
        worker = threading.Thread(
            target=_v8_single_query,
            args=(root, sitter, lang, calls, query_name, dependents, lock),
        )
        worker.start()
        workers.append(worker)
    for worker in workers:
        worker.join()

    most_dependents = max(dependents.values(), default=0)
    if most_dependents == 0:
        return {"V8": 0}
    return {"V8": most_dependents + 1}


def _v8_single_query(root, sitter, lang, calls, query, control_dependent_controls, thread_lock):
    """
    Tally the captures of one query by their outermost enclosing control.

    Captures are read from the ``condition`` tag when the query name contains
    ``"Q_CONDITION"`` and from the ``stmt`` tag otherwise. For every captured
    node that has at least one control ancestor, the counter entry keyed by
    the ``start_byte`` of the last ancestor returned by
    ``_traverse_parent_controls`` (the one furthest up the tree) is increased
    by one while ``thread_lock`` is held. ``calls`` is not used.
    """
    capture_tag = "stmt"
    if "Q_CONDITION" in query:
        capture_tag = "condition"

    for captured_node in sitter.captures(query, root, lang).get(capture_tag, []):
        enclosing_controls = _traverse_parent_controls(captured_node, lang)
        if len(enclosing_controls) == 0:
            continue
        outermost_control = enclosing_controls[-1]
        with thread_lock:
            control_dependent_controls[outermost_control.start_byte] += 1


def _traverse_parent_controls(node, lang):
    """ Climb up the AST and emit all control nodes. """
    control_types = [
        "if_statement",
        "switch_statement",
        "do_statement",
        "while_statement",
        "for_statement",
    ]
    if lang == "cpp":
        control_types.append("for_range_loop")
        
    parent_controls = []
    parent = node.parent
    while parent is not None:
        if parent.type in control_types:
            parent_controls.append(parent)
        parent = parent.parent
    return parent_controls


def v9(root, sitter, lang, calls=None):
    """
    V9: maximum number of data-dependent control structures

    For every ``condition`` capture of ``Q_CONDITION`` below ``root`` the
    distinct names of its ``variable`` captures (``Q_IDENTIFIER``) are
    determined, and each of those names is credited with one condition. The
    metric is the highest number of conditions credited to a single name, or
    0 if no name occurs in any condition.

    ``calls`` is part of the shared metric signature and is not used here.

    Returns a dict with the single key ``"V9"``.
    """
    sitter.add_queries({
        "Q_IDENTIFIER": Q_IDENTIFIER,
        "Q_CONDITION": Q_CONDITION,
    })

    # name -> number of conditions that mention it at least once
    conditions_per_name = Counter()
    for condition in sitter.captures("Q_CONDITION", root, lang).get("condition", []):
        variable_nodes = sitter.captures("Q_IDENTIFIER", condition, lang).get("variable", [])
        names_in_condition = {node.text.decode() for node in variable_nodes}
        conditions_per_name.update(names_in_condition)

    return {"V9": max(conditions_per_name.values(), default=0)}


def v10(root, sitter, lang, calls=None):
    """
    V10: number of if statements without else

    The metric is the number of ``stmt`` captures produced by
    ``Q_IF_WITHOUT_ELSE`` below ``root``. ``calls`` is part of the shared
    metric signature and is not used here.

    Returns a dict with the single key ``"V10"``.
    """
    sitter.add_queries({"Q_IF_WITHOUT_ELSE": Q_IF_WITHOUT_ELSE})

    captures = sitter.captures("Q_IF_WITHOUT_ELSE", root, lang)
    lone_ifs = captures.get("stmt", [])
    return {"V10": len(lone_ifs)}


def v11(root, sitter, lang, calls=None):
    """
    V11: number of variables in control structures (in each predicate)

    For every ``condition`` capture of ``Q_CONDITION`` below ``root`` the
    identifiers returned by ``get_identifiers`` are gathered; the metric is
    the number of distinct values gathered over all conditions.

    ``calls`` is forwarded unchanged as the ``filter`` keyword of
    ``get_identifiers`` (presumably callee names to ignore -- confirm against
    ``mcpp.parse``).

    Returns a dict with the single key ``"V11"``.
    """
    sitter.add_queries({
        "Q_CONDITION": Q_CONDITION
    })

    controlled_identifiers = set()
    for condition in sitter.captures("Q_CONDITION", root, lang).get("condition", []):
        controlled_identifiers.update(
            get_identifiers(sitter, condition, lang, filter=calls)
        )

    return {
        "V11": len(controlled_identifiers),
    }
