import warnings
# Suppress every UserWarning process-wide. The filter is installed before the
# imports below execute, so it is already active while those modules load.
warnings.simplefilter("ignore", UserWarning)

from dash import Dash, dcc, html, Input, Output, State, ALL
import dash_daq as daq
from .components import *
import time

import pandas as pd
import dash_dangerously_set_inner_html


# Configure how pandas prints frames: no row or column truncation and a
# 1024-character display width.
pd.options.display.max_rows = None
pd.options.display.max_columns = None
pd.options.display.width = 1024

class Aggregate:
    """Factory for the aggregate equity-curve Dash dashboard.

    Calling ``Aggregate()`` does not yield an ``Aggregate`` instance:
    ``__new__`` builds the layout, registers two callbacks and returns the
    ``Dash`` application. All state (curve table, caches, flags) is stored
    on the class itself and is therefore shared by everything that uses it.
    """

    chart_bg = '#1f2c56'
    df_show_agg = pd.DataFrame()
    # False while ``init_checklist`` is reloading ``df_show_agg``; the chart
    # callback waits for it to become True before reading the table.
    df_ready = False

    # True until the chart callback has fired once; that first call only
    # returns an empty placeholder figure.
    init = True

    @staticmethod
    def _base_figure(title_text):
        """Return an empty line figure styled with the dashboard's dark theme."""
        fig_line = px.line()
        fig_line.update_layout(title={'text': title_text})
        fig_line.update_xaxes(showline=True, zeroline=False, linecolor='white', gridcolor='rgba(0, 0, 0, 0)')
        fig_line.update_yaxes(showline=True, zeroline=False, linecolor='white', gridcolor='rgba(0, 0, 0, 0)')
        fig_line.update_layout(plot_bgcolor='#1a2245', paper_bgcolor='#1a2245', height=650,
                               margin=dict(l=0, r=25, t=60, b=0),
                               showlegend=True,
                               font={"color": "white", 'size': 10.5}, yaxis={'title': ''},
                               xaxis={'title': ''},
                               )
        return fig_line

    @classmethod
    def _load_equity_curves(cls, df_checked):
        """Return one equity column per selected curve plus their row-wise sum.

        Each row of ``df_checked`` must provide ``ref_code``, ``equity_path``,
        ``folder`` and ``backtest_name``. Equity frames are read from parquet
        once and then served from ``cls.df_equity_dict`` keyed by ``ref_code``.
        The result is sorted by index, gaps are forward- then back-filled, and
        the last column ``Aggregate_Equity`` is the sum of all curves.
        """
        df_list = []

        for i, row in df_checked.iterrows():
            ref_code = df_checked.at[i, 'ref_code']
            equity_path = df_checked.at[i, 'equity_path']

            if ref_code in cls.df_equity_dict:
                df = cls.df_equity_dict[ref_code]
            else:
                df = pd.read_parquet(equity_path)
                cls.df_equity_dict[ref_code] = df

            # Double-quoted f-string: reusing the same quote type inside the
            # replacement fields is a SyntaxError before Python 3.12.
            column_name = f"{row['folder']}_{row['backtest_name']}_{i}"
            df_list.append(df['equity_value'].to_frame(name=column_name))

        df_all_curves = pd.concat(df_list, axis=1).sort_index()
        df_all_curves = df_all_curves.ffill().bfill()
        df_all_curves["Aggregate_Equity"] = df_all_curves.sum(axis=1)
        return df_all_curves

    @staticmethod
    def _aggregate_statistics(df_checked, df_all_curves):
        """Compute portfolio-level statistics for the selected curves.

        ``df_checked['performance']`` holds one dict of per-curve metrics per
        row; ``df_all_curves`` must contain an ``Aggregate_Equity`` column on a
        date-like index.

        Returns ``(total_dict, year_list)`` where ``total_dict`` maps metric
        names to values (including one ``'<year>_return'`` entry per year) and
        ``year_list`` is every calendar year from the first to the last index
        value, inclusive.
        """
        df_performance = pd.DataFrame(list(df_checked['performance']))
        # These per-curve ratios cannot be summed; they are recomputed below
        # from the aggregate equity curve and the summed totals.
        df_performance = df_performance.drop(
            columns=['net_profit_to_mdd', 'mdd_dollar', 'mdd_pct', 'return_on_capital', 'sharpe_ratio'])

        total_dict = {
            element: df_performance[element].sum()
            for element in df_performance.columns
            if 'Win Rate' not in element and 'Return' not in element
        }

        total_dict['return_on_capital'] = total_dict['net_profit'] / total_dict['initial_capital']
        total_dict['win_rate'] = total_dict['num_of_win'] / total_dict['num_of_trade']

        df_agg = df_all_curves['Aggregate_Equity']

        # Maximum drawdown in currency and as a fraction of the running peak.
        running_max = df_agg.cummax()
        drawdown = running_max - df_agg
        mdd = drawdown.max()
        drawdown_percentage = drawdown / running_max * 100
        mdd_pct = round(drawdown_percentage.max() * 0.01, 5)

        total_dict['mdd_dollar'] = mdd
        total_dict['mdd_pct'] = mdd_pct
        total_dict['net_profit_to_mdd'] = total_dict['net_profit'] / total_dict['mdd_dollar']

        holding_period_day = (df_all_curves.index[-1] - df_all_curves.index[0]).days

        equity_value_pct_series = df_agg.pct_change().dropna()

        return_on_capital = total_dict['return_on_capital']

        if holding_period_day > 0:
            annualized_return = (1 + return_on_capital) ** (365 / holding_period_day) - 1
        else:
            # All data falls within a single day: annualising would divide by
            # zero, so report no annualised return instead of crashing.
            annualized_return = 0.0

        annualized_std = equity_value_pct_series.std() * math.sqrt(365)

        # A NaN standard deviation (fewer than two returns) also lands in the
        # else branch because ``nan > 0`` is False.
        if annualized_std > 0:
            annualized_sr = annualized_return / annualized_std
        else:
            annualized_sr = 0

        total_dict['annualized_return'] = annualized_return
        total_dict['annualized_std'] = annualized_std
        total_dict['annualized_sr'] = annualized_sr

        start_date_year = df_all_curves.index[0].year
        end_date_year = df_all_curves.index[-1].year
        year_list = list(range(start_date_year, end_date_year + 1))

        # Performance by year
        df_year = pd.DataFrame()
        df_year['equity_value'] = df_agg.copy()
        df_year['date'] = df_agg.index.copy()
        df_year['date'] = pd.to_datetime(df_year['date'], format='%Y-%m-%d')
        df_year['year'] = pd.DatetimeIndex(df_year['date']).year

        # Each year is measured against the previous year's closing equity;
        # the first year is measured against its own first observation.
        first_equity_value = 0
        for year in year_list:
            year_equity = df_year.loc[df_year['year'] == year, 'equity_value']
            if year_equity.empty:
                # No observations in this calendar year: equity is unchanged
                # from the previous close, so the return is zero.
                total_dict[f'{year}_return'] = 0.0
                continue
            if first_equity_value == 0:
                first_equity_value = year_equity.iloc[0]
            last_equity_value = year_equity.iloc[-1]
            total_dict[f'{year}_return'] = (last_equity_value - first_equity_value) / first_equity_value
            first_equity_value = last_equity_value

        return total_dict, year_list

    def __new__(self):
        """Build the dashboard and return the configured ``Dash`` app.

        Because this is ``__new__``, ``self`` is bound to the class itself, so
        every ``self.x = ...`` below sets a class attribute.
        """

        self.chart_bg = '#1f2c56'
        self.valid_colour = 'DeepSkyBlue'

        components = Components(0)
        empty_line_chart = components.empty_line_chart()

        self.df_show_agg, checklist_div = components.aggregate_df()
        # Caches: equity frames by ref_code; figures and performance panels by
        # the selection key built in ``aggregate_chart``.
        self.df_equity_dict = {}
        self.fig_line_dict = {}
        self.aggregate_performance_dict = {}

        app = Dash(__name__, external_stylesheets=[dbc.themes.SUPERHERO], suppress_callback_exceptions=True)

        my_css_data = """body {
          background-color: #1a2245;
        }
                /* width */
        ::-webkit-scrollbar {
          width: 10px !important;
          display: block !important;
        }
        
        /* Track */
        ::-webkit-scrollbar-track {
          background: #1f2c56 !important;
          border-radius: 10px !important;
          display: block !important;
        }
        
        
        /* Handle */
        ::-webkit-scrollbar-thumb {
          background: #154360;
          border-radius: 10px;
        }
        """
        innerHtmlText = "<style>%s</style>" % my_css_data

        app.layout = html.Div([

            dash_dangerously_set_inner_html.DangerouslySetInnerHTML(innerHtmlText),

            dbc.Row([
                # Left Column
                dbc.Col(html.Div([

                    html.Div(style={'height': '10px', }),

                    html.Div(children=html.Div([
                        html.Div(style={'height': '10px', }),

                        html.Div('Equity Curves', style={'color': 'DeepSkyBlue', 'font-size': '17px'}),

                        html.Div(style={'height': '10px', }),

                        html.Div(id='checklist-container', children=checklist_div,
                                 style={
                                        'maxHeight': '625px',
                                        'minHeight': '625px',
                                        'overflow-y': 'scroll',
                                        'overflow-x': 'hidden',
                                        }),

                    ],), style={'padding': '5px 5px','padding-left':'15px', 'border-radius': '5px',
                                'font-size': '13px','background-color': self.chart_bg},
                    ),

                ],), style={'padding': '0', 'padding-left': '5px',}, width=3),

                # Right Column
                dbc.Col(html.Div([

                    html.Div(style={'height': '15px', }),

                    html.Div([

                        html.Div(style={'height': '10px', }),

                        html.Div(
                            [
                                html.Div('Aggregate', id='aggregate_string', style={'color': self.valid_colour, 'margin-left': '25px'}),
                                daq.BooleanSwitch(on=True, color=self.valid_colour, id='aggregate_boolean'),
                            ],
                            style={
                                'display': 'flex',
                                'align-items': 'center',
                                'gap': '10px',  # Adds some spacing between text and the switch
                            },
                        ),
                    ]),

                    html.Div(style={'height': '10px', }),

                    html.Div(children=html.Div([

                        dbc.Row([

                            dbc.Col([

                                html.Div(style={'height': '15px', }),

                                html.Div(id='performance_div',
                                         style={'padding-left': '10px', 'font-size': '14px'}),

                                html.Div(style={'height': '15px', }),

                            ],style={'padding':'0px', 'padding-left': '28px',},width=3),

                            dbc.Col([

                                html.Div(style={'height': '5px', }),

                                html.Div(id='chart_area',
                                         children=dcc.Graph(id='line_chart',
                                                            figure=empty_line_chart)),

                                html.Div(style={'height': '20px', }),

                            ],style={'padding':'0px','padding-right': '35px',},width=9),
                        ]),

                    ]), style={'padding': '0px','border-radius': '5px',
                               }),

                ]), style={'padding': '0', 'padding-left': '5px'}, width=9),

            ]),

        ], style={'width': '1500px', 'margin': 'auto', 'padding': '0px', 'color': 'white'})


        # For initial curve list
        @app.callback(
            Output('checklist-container', 'children'),
            Input('checklist-container', 'children'),
        )
        def init_checklist(checklist_container):
            """Reload the curve table and return the refreshed checklist."""
            self.df_ready = False
            try:
                self.df_show_agg, checklist_div = components.aggregate_df()
            finally:
                # If the reload fails, ``df_show_agg`` still holds the previous
                # table, so release the chart callback instead of leaving it
                # polling ``df_ready`` forever.
                self.df_ready = True

            return checklist_div


        @app.callback(
            Output('line_chart', 'figure'),
            Output('aggregate_string', 'style'),
            Output('performance_div', 'children'),
            Input('curve_checklist', 'value'),
            Input('aggregate_boolean', 'on')
        )
        def aggregate_chart(curve_checklist, aggregate_boolean):
            """Return (figure, 'Aggregate' label style, performance panel).

            ``curve_checklist`` holds the index labels of the selected rows of
            ``df_show_agg``; ``aggregate_boolean`` chooses between one summed
            curve and one trace per selected curve.
            """

            if self.init:
                # Very first invocation: return an empty themed placeholder.
                self.init = False
                return self._base_figure(''), {}, html.Div()

            # Wait until ``init_checklist`` has finished loading the table.
            while not self.df_ready:
                time.sleep(0.1)

            if aggregate_boolean:
                aggregate_style = {'color': self.valid_colour, 'margin-left': '25px'}
                title_text = 'Aggregate Equity Curves'
            else:
                aggregate_style = {'color': 'Grey', 'margin-left': '25px'}
                title_text = 'Equity Curves'

            if not curve_checklist:
                # Nothing selected (empty list or None): there is nothing to
                # concatenate, so show an empty chart rather than raising.
                return self._base_figure(title_text), aggregate_style, html.Div()

            # Sort a copy so the cache key is order-independent without
            # mutating the list handed over by Dash.
            curve_checklist = sorted(curve_checklist)
            all_dict_key = str(curve_checklist) + ' |agg_' + str(aggregate_boolean).lower()

            performance_cached = all_dict_key in self.aggregate_performance_dict
            figure_cached = all_dict_key in self.fig_line_dict

            if performance_cached and figure_cached:
                return (self.fig_line_dict[all_dict_key], aggregate_style,
                        self.aggregate_performance_dict[all_dict_key])

            # At least one artefact is missing, and both are derived from the
            # combined curves, so always build them here.
            df_show_agg_checked = self.df_show_agg.loc[curve_checklist]
            df_all_curves = self._load_equity_curves(df_show_agg_checked)

            if performance_cached:
                aggregate_performance = self.aggregate_performance_dict[all_dict_key]
            else:
                total_dict, year_list = self._aggregate_statistics(df_show_agg_checked, df_all_curves)

                # NOTE(review): the original passed undefined ``start_date`` /
                # ``end_date`` names, so this lookup always failed and the
                # fallback of 1 was always used. The curve's first and last
                # index values are passed now; confirm the argument format
                # ``plotguy.get_risk_free_rate`` expects.
                start_date = df_all_curves.index[0]
                end_date = df_all_curves.index[-1]
                try:
                    risk_free_rate = plotguy.get_risk_free_rate(start_date, end_date)
                except Exception:
                    # Best effort: any lookup failure falls back to 1.
                    risk_free_rate = 1

                aggregate_performance = components.aggregate_performance(total_dict, year_list, risk_free_rate)
                self.aggregate_performance_dict[all_dict_key] = aggregate_performance

            # Generate Chart
            if figure_cached:
                fig_line = self.fig_line_dict[all_dict_key]
            else:
                fig_line = self._base_figure(title_text)

                if aggregate_boolean:
                    fig_line.add_trace(go.Scatter(mode='lines',
                                                  x=df_all_curves.index, y=df_all_curves['Aggregate_Equity'],
                                                  line=dict(color=self.valid_colour, width=1.5), name='Aggregate'), )
                else:
                    # Every column except the trailing 'Aggregate_Equity' is
                    # one selected curve, in the same order as the checklist.
                    columns = list(df_all_curves.columns)[:-1]
                    for curve_number, column in zip(curve_checklist, columns):
                        fig_line.add_trace(go.Scatter(mode='lines',
                                                      x=df_all_curves.index, y=df_all_curves[column],
                                                      line=dict(color=self.df_show_agg.loc[curve_number].line_colour, width=1.5),
                                                      name=f'Curve {str(curve_number).zfill(3)}'), )

                self.fig_line_dict[all_dict_key] = fig_line

            return fig_line, aggregate_style, aggregate_performance

        return app