import warnings
warnings.simplefilter("ignore", UserWarning)

import datetime
import os
import sys
import copy
import multiprocessing as mp
import time
import zlib
import requests
import pandas as pd
import numpy as np

from .equity_curves import *
from .signals import *
from .aggregate import *
from .components import *
import itertools

def get_all_para_combination(para_dict, backtest_attribute, df_dict, sec_profile, manager_list):
    """Expand ``para_dict`` into one settings dict per parameter combination.

    Every combination of the value lists in ``para_dict`` is merged with
    ``backtest_attribute`` and keyed by a CRC32 reference code computed from
    the string form of the merged values.

    Args:
        para_dict: Mapping of parameter name -> iterable of candidate values.
        backtest_attribute: Shared backtest settings. Must contain
            ``'start_date'``, ``'end_date'`` and ``'freq'``. It is updated in
            place with ``'summary_mode'``, ``'intraday'`` and
            ``'risk_free_rate'``.
        df_dict: Mapping of security code -> source data for that code.
        sec_profile: Object stored unchanged under ``'sec_profile'``.
        manager_list: Result collector stored under ``'manager_list'``.

    Returns:
        dict mapping ``ref_code`` (int) to the settings dict of that
        combination. Combinations that hash to the same ``ref_code``
        overwrite earlier ones.

    Raises:
        KeyError: If a required key is missing, including ``'code'`` in the
            merged settings or the code in ``df_dict``.
    """
    risk_free_rate = get_risk_free_rate(backtest_attribute['start_date'], backtest_attribute['end_date'])

    # Any frequency other than daily bars is treated as intraday.
    intraday = backtest_attribute['freq'] != '1D'
    if intraday:
        # Respect a caller-supplied summary_mode; default to True otherwise.
        backtest_attribute.setdefault('summary_mode', True)
    else:
        backtest_attribute['summary_mode'] = False
    backtest_attribute['intraday'] = intraday
    backtest_attribute['risk_free_rate'] = risk_free_rate

    para_keys = list(para_dict)
    para_keys_str = '|'.join(para_keys)

    # Materialise the cartesian product once; it is needed for both the
    # count printed below and the loop.
    all_combinations = list(itertools.product(*para_dict.values()))
    print('number of combination:', len(all_combinations))

    all_para_combination = {}
    for combination in all_combinations:
        para_combination = dict(zip(para_keys, combination))
        para_combination.update(backtest_attribute)
        para_combination['manager_list'] = manager_list

        # The reference code must be computed at this point: it hashes exactly
        # the values present so far, in insertion order. Changing which keys
        # exist here, or their order, would change every ref_code.
        concatenated_values = ''.join(map(str, para_combination.values())).encode('utf-8')
        ref_code = zlib.crc32(concatenated_values)

        # Attached after hashing so they do not influence the ref_code.
        para_combination['df'] = df_dict[para_combination['code']]
        para_combination['sec_profile'] = sec_profile
        para_combination['para_keys_str'] = para_keys_str

        all_para_combination[ref_code] = para_combination

    return all_para_combination

def generate_or_read_backtest_result(read_only, mp_mode, number_of_core, manager_list,
                                     all_para_combination, backtest):
    """Run the backtests (or load cached results) and return their summary rows.

    Args:
        read_only: When truthy and ``backtest_result.parquet`` exists in the
            current working directory, the cached file is loaded and no
            backtest is run.
        mp_mode: When truthy, run ``backtest`` through a multiprocessing pool;
            otherwise run it sequentially in this process.
        number_of_core: Number of worker processes for the pool.
        manager_list: Collection the ``backtest`` callable fills with one
            result dict per combination; each dict needs a ``'ref_code'`` key.
        all_para_combination: Mapping of ``ref_code`` -> settings dict.
        backtest: Callable invoked with each ``(ref_code, settings)`` item.

    Returns:
        pandas.DataFrame indexed by ``ref_code`` and restricted to the
        reference codes present in ``all_para_combination``.
    """
    result_path = 'backtest_result.parquet'

    if read_only and os.path.isfile(result_path):
        backtest_result_df = pd.read_parquet(result_path)
    else:
        t1 = datetime.datetime.now()
        if mp_mode:
            # The context manager guarantees the workers are released even if
            # a backtest raises; map() blocks until every item has finished.
            with mp.Pool(processes=number_of_core) as pool:
                pool.map(backtest, all_para_combination.items())
        else:
            for para_combination_item in all_para_combination.items():
                backtest(para_combination_item)

        backtest_result_df = pd.DataFrame(list(manager_list))
        backtest_result_df = backtest_result_df.set_index('ref_code')

        if os.path.isfile(result_path):
            old_backtest_result_df = pd.read_parquet(result_path)
            backtest_result_df = pd.concat([backtest_result_df, old_backtest_result_df])
            # Fresh rows come first in the concat, so keep='first' lets a
            # re-run replace the stale cached row for the same ref_code.
            backtest_result_df = backtest_result_df[~backtest_result_df.index.duplicated(keep='first')]

        backtest_result_df.to_parquet(result_path)

        # total_seconds() instead of .seconds: the latter drops whole days.
        elapsed_seconds = int((datetime.datetime.now() - t1).total_seconds())
        print('backtest time used:', elapsed_seconds, 'seconds')

    # The cache may hold rows from earlier sweeps; return only the requested ones.
    backtest_result_df = backtest_result_df[backtest_result_df.index.isin(all_para_combination.keys())]

    return backtest_result_df

def get_source_data_path(data_folder, code, freq):
    """Return the parquet path ``<data_folder>/<code>_<freq>.parquet``."""
    parquet_name = ''.join((code, '_', freq, '.parquet'))
    return os.path.join(data_folder, parquet_name)


def save_backtest_result(df, para_combination_item):
    """Summarise one backtest run, save its equity curve and record the stats.

    Side effects:
        * Writes ``<equity_curve_folder>/<ref_code>.parquet``.
        * When ``'summary_mode'`` is present and falsy and at least one trade
          was opened, also writes
          ``<equity_curve_folder>/<ref_code>_non-summary-intraday.parquet``.
        * Removes ``'df'``, ``'sec_profile'`` and ``'manager_list'`` from the
          settings dict, merges the statistics into it and appends it to the
          ``manager_list`` it was carrying.

    Args:
        df: Backtest output with a datetime index and the columns
            ``'commission'``, ``'action'``, ``'trd_side'``,
            ``'equity_value'`` plus the column named by
            ``default_market_price``.
        para_combination_item: ``(ref_code, para_combination)`` pair. The
            settings dict must provide ``'equity_curve_folder'``,
            ``'manager_list'``, ``'risk_free_rate'`` (in percent),
            ``'default_market_price'``, ``'intraday'``, ``'df'`` and
            ``'sec_profile'``.

    Raises:
        ValueError: If ``intraday`` is truthy and ``default_market_price`` is
            neither ``'close'`` nor ``'open'`` (no daily aggregation rule).
    """
    ref_code, para_combination = para_combination_item

    equity_curve_folder    = para_combination['equity_curve_folder']
    manager_list           = para_combination['manager_list']
    risk_free_rate         = para_combination['risk_free_rate']
    default_market_price   = para_combination['default_market_price']
    intraday               = para_combination['intraday']

    # Annualisation factor for the standard deviation of daily returns.
    annualization = np.sqrt(365)

    total_commission = df['commission'].sum()

    df = df[[default_market_price, 'action', 'trd_side', 'equity_value']]
    df = df.rename(columns={default_market_price: 'price'})

    equity_curve_save_path = os.path.join('', equity_curve_folder, f'{ref_code}.parquet')

    ### for intraday ###
    if intraday:
        # Daily bar takes the last close or the first open of the day.
        agg_price_by_column = {'close': 'last', 'open': 'first'}
        if default_market_price not in agg_price_by_column:
            raise ValueError(
                f"Unsupported default_market_price {default_market_price!r} for intraday data; "
                "expected 'close' or 'open'")
        df_daily = df.resample('D').agg({'price': agg_price_by_column[default_market_price]})
    else:
        # Deliberate alias: the drawdown columns added below land on ``df``
        # too and are part of the optional non-summary parquet written later.
        df_daily = df

    ########### price result ###############
    df_daily['running_max'] = df_daily['price'].cummax()
    df_daily['dd_dollar'] = df_daily['running_max'] - df_daily['price']
    df_daily['dd_pct'] = df_daily['dd_dollar'] / df_daily['running_max'] * 100
    price_mdd_dollar = df_daily['dd_dollar'].max()
    price_mdd_pct = df_daily['dd_pct'].max()

    price_pct_series = df_daily['price'].pct_change().dropna()
    price_net_profit = df_daily.at[df_daily.index[-1], 'price'] - df_daily.at[df_daily.index[0], 'price']
    holding_period_day = (df_daily.index[-1].date() - df_daily.index[0].date()).days

    price_return_on_capital = price_net_profit / df_daily.at[df_daily.index[0], 'price']
    price_annualized_return = (np.sign(1 + price_return_on_capital) * np.abs(1 + price_return_on_capital)) ** (
        365 / holding_period_day) - 1
    price_annualized_std = price_pct_series.std() * annualization
    if price_annualized_std != 0:
        # risk_free_rate is in percent, the return is a fraction.
        price_annualized_sr = (price_annualized_return - risk_free_rate / 100) / price_annualized_std
    else:
        price_annualized_sr = 0
    price_net_profit_to_mdd = price_net_profit / price_mdd_dollar if price_mdd_dollar != 0 else 0

    price_return_on_capital = round(100 * price_return_on_capital, 2)
    price_annualized_return = round(100 * price_annualized_return, 2)
    price_annualized_std = round(100 * price_annualized_std, 2)
    price_annualized_sr = round(price_annualized_sr, 2)
    price_net_profit_to_mdd = round(100 * price_net_profit_to_mdd, 2)

    # Rows that open a position; each one counts as a trade.
    df_count = df[(df['trd_side'] == 'BUY') | (df['trd_side'] == 'SELL_SHORT')].copy()

    if len(df_count) == 0:
        # No position was ever opened: report neutral equity statistics.
        num_of_trade = 0
        equity_net_profit = 0
        equity_return_on_capital = 0
        equity_annualized_return = 0
        equity_annualized_std = 0
        equity_annualized_sr = 0
        equity_net_profit_to_mdd = np.inf
        equity_mdd_dollar = 0
        equity_mdd_pct = 0

        num_of_win = 0
        win_rate = 0
        yearly_stats_string = ''

        cov_return = 0
        cov_count = 0
        total_commission = 0

    else:

        ########### by year count, win rate and return ###############

        num_of_trade = len(df_count)
        # Append the last row carrying any trade side so the final opened
        # trade also gets an equity difference.
        df_count = pd.concat([df_count, df[df['trd_side'] != ''].tail(1)])
        df_count['realized_pnl'] = df_count['equity_value'] - df_count['equity_value'].shift(1)
        df_count['win_trade'] = df_count['realized_pnl'] >= 0

        num_of_win = df_count['win_trade'].sum()
        win_rate = round(100 * num_of_win / num_of_trade, 2)

        yearly_stats = df_count.groupby(df_count.index.year).agg(
            year_pnl=('realized_pnl', 'sum'),
            year_win_count=('win_trade', 'sum'),
            year_trade_count=('realized_pnl', 'count'),
            year_start_equity_value=('equity_value', 'first'))
        yearly_stats['year_return'] = 100 * (yearly_stats['year_pnl'] / yearly_stats['year_start_equity_value'])
        yearly_stats['year_win_rate'] = 100 * (yearly_stats['year_win_count'] / yearly_stats['year_trade_count'])

        # Coefficients of variation across years (std / mean).
        year_win_rate_mean = yearly_stats['year_win_rate'].mean()
        cov_return = yearly_stats['year_win_rate'].std() / year_win_rate_mean if year_win_rate_mean != 0 else 0
        year_trade_count_mean = yearly_stats['year_trade_count'].mean()
        cov_count = yearly_stats['year_trade_count'].std() / year_trade_count_mean if year_trade_count_mean != 0 else 0

        # Element-wise formatting via Series.map; DataFrame.applymap is
        # deprecated in recent pandas releases.
        yearly_stats = yearly_stats.apply(lambda column: column.map(lambda x: f'{x:.2f}'))
        formatted_rows = yearly_stats.apply(
            lambda row: (f"{row.name},year_trade_count:{row['year_trade_count']},"
                         f"year_win_rate:{row['year_win_rate']},year_return:{row['year_return']}"),
            axis=1)

        yearly_stats_string = "|".join(formatted_rows)

        ######## keep the un-resampled curve when summary mode is off #############
        if 'summary_mode' in para_combination and not para_combination['summary_mode']:
            equity_curve_non_summary_save_path = os.path.join(
                '', equity_curve_folder, f'{ref_code}_non-summary-intraday.parquet')
            df.to_parquet(equity_curve_non_summary_save_path)

        ########### equity value result ###############

        ### for intraday ###
        if intraday:
            # Daily equity is the last value on each day with a trade row;
            # days without one are filled forward, then backward.
            df = df_count.resample('D').agg({'equity_value': 'last'})
            df = pd.concat([df_daily, df], axis=1)
            df = df[df['price'].notna()]
            df['equity_value'] = df['equity_value'].ffill()
            df['equity_value'] = df['equity_value'].bfill()

        df['equity_value'] = df['equity_value'].astype(np.int32)
        ######################

        df['running_max'] = df['equity_value'].cummax()
        df['dd_dollar'] = df['running_max'] - df['equity_value']
        df['dd_pct'] = df['dd_dollar'] / df['running_max'] * 100
        equity_mdd_dollar = df['dd_dollar'].max()
        equity_mdd_pct = df['dd_pct'].max()

        holding_period_day = (df.index[-1].date() - df.index[0].date()).days
        equity_pct_series = df['equity_value'].pct_change().dropna()
        equity_net_profit = df.at[df.index[-1], 'equity_value'] - df.at[df.index[0], 'equity_value']

        equity_return_on_capital = equity_net_profit / df.at[df.index[0], 'equity_value']
        equity_annualized_return = (np.sign(1 + equity_return_on_capital) * np.abs(1 + equity_return_on_capital)) ** (
            365 / holding_period_day) - 1
        equity_annualized_std = equity_pct_series.std() * annualization
        if equity_annualized_std != 0:
            equity_annualized_sr = (equity_annualized_return - risk_free_rate / 100) / equity_annualized_std
        else:
            equity_annualized_sr = 0
        equity_net_profit_to_mdd = equity_net_profit / equity_mdd_dollar if equity_mdd_dollar != 0 else 0

        equity_return_on_capital = round(100 * equity_return_on_capital, 2)
        equity_annualized_return = round(100 * equity_annualized_return, 2)
        equity_annualized_std = round(100 * equity_annualized_std, 2)
        equity_annualized_sr = round(equity_annualized_sr, 2)
        equity_net_profit_to_mdd = round(100 * equity_net_profit_to_mdd, 2)

    return_on_capital_diff = equity_annualized_return - price_annualized_return

    if intraday:
        df = df[['price', 'equity_value']]
    else:
        df = df[['price', 'action', 'trd_side', 'equity_value']]
    df.to_parquet(equity_curve_save_path)
    print('backtest is runnung...', equity_curve_save_path)

    ####################################
    # Drop the bulky / unserialisable entries before the settings dict is
    # turned into a result row.
    del para_combination['df']
    del para_combination['sec_profile']
    del para_combination['manager_list']

    cov_return = round(cov_return, 2)
    cov_count = round(cov_count, 2)
    total_commission = int(round(total_commission))

    backtest_result_dict = {
        'ref_code': ref_code,
        'num_of_trade': num_of_trade,

        'equity_net_profit': equity_net_profit,
        'equity_return_on_capital': equity_return_on_capital,
        'equity_annualized_return': equity_annualized_return,
        'equity_annualized_std': equity_annualized_std,
        'equity_annualized_sr': equity_annualized_sr,
        'equity_net_profit_to_mdd': equity_net_profit_to_mdd,
        'equity_mdd_dollar': equity_mdd_dollar,
        'equity_mdd_pct': equity_mdd_pct,

        'price_net_profit': price_net_profit,
        'price_return_on_capital': price_return_on_capital,
        'price_annualized_return': price_annualized_return,
        'price_annualized_std': price_annualized_std,
        'price_annualized_sr': price_annualized_sr,
        'price_net_profit_to_mdd': price_net_profit_to_mdd,
        'price_mdd_dollar': price_mdd_dollar,
        'price_mdd_pct': price_mdd_pct,

        'return_on_capital_diff': return_on_capital_diff,

        'num_of_win': num_of_win,
        'win_rate': win_rate,
        'yearly_stats_string': yearly_stats_string,
        'cov_return': cov_return,
        'cov_count': cov_count,
        'total_commission': total_commission
    }

    para_combination.update(backtest_result_dict)
    manager_list.append(para_combination)


########################################################################################################
########################################################################################################
########################################################################################################

def _year_of_date_string(date_string):
    """Return the year of a ``YYYY-MM-DD`` or ``YYYYMMDD`` date string."""
    date_format = '%Y-%m-%d' if '-' in date_string else '%Y%m%d'
    return datetime.datetime.strptime(date_string, date_format).year


def get_risk_free_rate(start_date, end_date):
    """Return the risk free rate, in percent, for the backtest period.

    If both dates fall in the current calendar year the latest published rate
    is fetched; otherwise the geometric mean of the yearly rates between the
    two years is used. If the lookup fails for any reason the rate falls back
    to 2 %.

    Args:
        start_date: Period start as ``'YYYY-MM-DD'`` or ``'YYYYMMDD'``.
        end_date: Period end as ``'YYYY-MM-DD'`` or ``'YYYYMMDD'``. Each date
            is parsed on its own, so the two may use different formats.

    Returns:
        The rate in percent (for example ``2`` means 2 %).

    Raises:
        ValueError: If either date string matches neither supported format.
    """
    start_date_year = _year_of_date_string(start_date)
    end_date_year = _year_of_date_string(end_date)

    try:
        if end_date_year == start_date_year and end_date_year == datetime.datetime.now().year:
            risk_free_rate = get_latest_fed_fund_rate()
        else:
            risk_free_rate = get_geometric_mean_of_yearly_rate(start_date_year, end_date_year)
    except Exception as exc:
        # Best-effort lookup: any failure (network, parsing, missing data)
        # falls back to a fixed rate. ``Exception`` rather than a bare
        # ``except`` so Ctrl-C and interpreter shutdown are not swallowed.
        risk_free_rate = 2  # if network error, set rate to 2 %
        print('Network error. Risk free rate: {:.2f} %'.format(risk_free_rate))
        print('Risk free rate lookup failed with:', repr(exc))

    return risk_free_rate


# def plot_signal_analysis(py_filename, output_folder, start_date, end_date, para_dict, signal_settings):
#     app = signals.Signals(py_filename, output_folder, start_date, end_date, para_dict, generate_backtest_output_path,
#                           signal_settings)
#
#     return app


def plot(mode, backtest_result_df = None, number_of_curves=20):
    """Build the plotting application selected by ``mode``.

    Args:
        mode: ``'equity_curves'`` to build ``equity_curves.Plot`` or
            ``'aggregate'`` to build ``aggregate.Aggregate``.
        backtest_result_df: Passed to ``equity_curves.Plot``; unused for
            ``'aggregate'``.
        number_of_curves: Passed to ``equity_curves.Plot``; unused for
            ``'aggregate'``.

    Returns:
        The constructed application object.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    if mode == 'equity_curves':
        return equity_curves.Plot(backtest_result_df, number_of_curves)

    if mode == 'aggregate':
        return aggregate.Aggregate()

    # Previously an unknown mode fell through to ``return app`` with ``app``
    # never assigned, which surfaced as an UnboundLocalError.
    raise ValueError(f"Unknown plot mode {mode!r}; expected 'equity_curves' or 'aggregate'")


def get_latest_fed_fund_rate(timeout=10):
    """Scrape the latest FEDFUNDS observation from the FRED series page.

    Args:
        timeout: Seconds to wait for the HTTP request before giving up, so a
            stalled connection cannot block the backtest forever.

    Returns:
        The observation as a float rounded to two decimals. It is returned as
        shown on the page, not divided by 100.

    Raises:
        requests.RequestException: On connection failure, timeout or a
            non-success HTTP status.
        ValueError: If the value element is missing from the page or its text
            is not a number.
    """
    url = "https://fred.stlouisfed.org/series/FEDFUNDS"
    page = requests.get(url, timeout=timeout)
    page.raise_for_status()
    soup = BeautifulSoup(page.content, "html.parser")

    value_element = soup.find("span", class_="series-meta-observation-value")
    if value_element is None:
        # The page layout changed or an unexpected page was served.
        raise ValueError("Could not find the latest observation value on the FEDFUNDS page")

    fed_funds_rate = value_element.text
    print("Latest Federal Funds Rate:", fed_funds_rate, '%')
    fed_funds_rate = round(float(fed_funds_rate), 2)
    return fed_funds_rate


def get_geometric_mean_of_yearly_rate(start_year, end_year, timeout=10):  # backtest period
    """Return the geometric mean of yearly average rates for a year range.

    Downloads the FRED ``DTB3`` series as CSV, averages the observations per
    calendar year, and returns the geometric mean of the yearly averages for
    ``start_year`` through ``end_year`` inclusive.

    Args:
        start_year: First calendar year to include.
        end_year: Last calendar year to include.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The geometric mean rounded to two decimals, in the series' own unit
        (presumably percent - confirm against the FRED series notes).

    Raises:
        requests.RequestException: On connection failure, timeout or a
            non-success HTTP status.
        ValueError: If the series has no observations in the requested years
            (previously this silently produced NaN).
    """
    url = "https://fred.stlouisfed.org/graph/fredgraph.csv?id=DTB3"
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()

    # splitlines() copes with both "\n" and "\r\n" and does not lose the last
    # record when the body has no trailing newline.
    rows = [line.split(",") for line in response.text.splitlines() if line.strip()]
    df = pd.DataFrame(rows[1:], columns=["date", "risk_free_rate"])
    df["date"] = pd.to_datetime(df["date"])
    # Non-numeric cells become NaN and are dropped.
    df["risk_free_rate"] = pd.to_numeric(df["risk_free_rate"], errors='coerce')
    df = df.dropna(subset=['risk_free_rate'])

    # Group by calendar year instead of resample("A"): the "A" alias is
    # deprecated in recent pandas releases, and the yearly means are the same.
    yearly_rate = df.groupby(df["date"].dt.year)["risk_free_rate"].mean().round(3)

    # keep only the years between start_year and end_year
    yearly_rate = yearly_rate[(yearly_rate.index >= start_year) & (yearly_rate.index <= end_year)]
    if yearly_rate.empty:
        raise ValueError(
            "No risk free rate observations between {} and {}".format(start_year, end_year))

    # Geometric mean computed as exp(mean(log(x))).
    fed_fund_rate_geometric_mean = np.exp(np.log(yearly_rate).mean())
    fed_fund_rate_geometric_mean = round(fed_fund_rate_geometric_mean, 2)
    print("Federal Funds Rate Geometric mean from {} to {}: {} %".format(start_year, end_year,
                                                                         fed_fund_rate_geometric_mean))

    return fed_fund_rate_geometric_mean

