#! /usr/share/python
import argparse
import MDAnalysis as mda
from CodeEntropy.FunctionCollection import EntropyFunctions as EF
from CodeEntropy.ClassCollection import DataContainer as DC
from CodeEntropy.IO import MDAUniverseHelper as MDAHelper
import pandas as pd
import numpy as np
from datetime import datetime
from CodeEntropy.ClassCollection.PoseidonClass import Poseidon

try:
    # Build the command-line interface and parse sys.argv into `args`.
    # Every option's `dest` is read back by name further down the script, so
    # the dest spellings (including "puteAtomNum" and "excludedResnanes")
    # are kept exactly as they are.
    parser = argparse.ArgumentParser(description="""
    CodeEntropy-POSEIDON is a tool to compute entropy using the multiscale-cell-correlation (MCC) theory and force/torque covariance methods with the ability to compute solvent entropy.
    Version:
        0.3.1;

    Author:
        Arghya Chakravorty (arghya90),
        Jas Kalayan (jkalayan);

    Output:
        *.csv = results from different calculations,
        *.pkl - Pickled reduced universe for further analysis,
        *.out - detailed output such as matrix and spectra""")

    # ---- Input selection -------------------------------------------------
    # NOTE: a mutually exclusive "-i/--pickle" input (a pickled
    # MDAnalysis.Universe for unsupported formats) was sketched here at one
    # point but is disabled, so the topology/trajectory option is required.
    parser.add_argument(
        '-f', '--top_traj_file',
        dest="filePath",
        required=True,
        nargs='+',
        help="Path to Structure/topology file (AMBER PRMTOP, GROMACS TPR "
             "which contains topology and dihedral information) followed by "
             "Trajectory file(s) (AMBER NETCDF or GROMACS TRR). Required.")
    parser.add_argument(
        '-l', '--selectString',
        dest="selectionString",
        type=str,
        default='all',
        help='Selection string for CodeEntropy such as protein or resid, '
             'refer to MDAnalysis.select_atoms for more information.')

    # ---- Frame range -----------------------------------------------------
    parser.add_argument(
        '-b', '--begin',
        dest="start",
        type=int,
        default=0,
        help="Start analysing the trajectory from this frame index. Defaults to 0")
    parser.add_argument(
        '-e', '--end',
        dest="end",
        type=int,
        default=-1,
        help="Stop analysing the trajectory at this frame index. "
             "Defaults to -1 (end of trajectory file)")
    parser.add_argument(
        '-d', '--step',
        dest="step",
        type=int,
        default=1,
        help="Interval between two consecutive frames to be read. Defaults to 1")

    # ---- Calculation settings --------------------------------------------
    parser.add_argument(
        '-k', '--tempra',
        dest="temp",
        type=float,
        default=298.0,
        help="Temperature for entropy calculation (K). Default to 298.0 K")
    parser.add_argument(
        '-t', '--thread',
        dest="thread",
        type=int,
        default=1,
        help="How many multiprocess to use. Default 1 for single core execution.")

    # ---- Output files ----------------------------------------------------
    parser.add_argument(
        "-o", "--out",
        dest="outFile",
        default="outfile.out",
        help="Name of the file where the text format output will be written. "
             "Default: outfile.out")
    parser.add_argument(
        "-v", "--csvout",
        dest="csvOutFile",
        default="outfile.csv",
        help="Name of the file where the total entropy output will be written. "
             "Default: outfile.csv")
    parser.add_argument(
        "-r", "--resout",
        dest="rescsvOutFile",
        default="res_outfile.csv",
        help="Name of the file where the residue entropy output will be written. "
             "Default: res_outfile.csv")
    parser.add_argument(
        "-m", "--mout",
        dest="moutFile",
        help="Name of the file where certain matrices will be written (default: None).")
    parser.add_argument(
        "-n", "--nmd",
        dest="nmdFile",
        help="Name of the file where VMD compatible NMD format files with "
             "mode information will be printed (default: None).")
    parser.add_argument(
        '-a', '--rotationalaxis',
        dest="rotationalAxis",
        nargs='+',
        default=['C', 'CA', 'N'],
        help="The 3 atom name in each residue for rotational axis. "
             "Default for protein ['C', 'CA', 'N'] axis")

    # NOTE: force/torque scaling options ("--fscale" / "--tscale", float,
    # default 1.0) are disabled for now since POSEIDON does not have this
    # option; the script uses fixed scale factors instead.

    # ---- POSEIDON (solvent) settings -------------------------------------
    parser.add_argument(
        '-c', '--cutShell',
        dest="cutShell",
        type=float,
        help='include cutoff shell analysis, add cutoff distance in angstrom '
             'Default None will use the RAD Algorithm')
    parser.add_argument(
        '-p', '--pureAtomNum',
        dest="puteAtomNum",
        type=int,
        default=1,
        help='Reference molecule resid for system of pure liquid. Default to 1')
    parser.add_argument(
        '-x', '--excludedResnames',
        dest="excludedResnanes",
        nargs='+',
        help='exclude a list of molecule names from nearest non-like analysis. '
             'Default: None. Multiples are gathered into list.')
    # NOTE(review): the default here is the plain string 'WAT', whereas a
    # value given on the command line arrives as a list (nargs='+').
    # Confirm the consumer handles both before changing either.
    parser.add_argument(
        '-w', '--water',
        dest="waterResnames",
        nargs='+',
        default='WAT',
        help='resname for water molecules. Default: WAT. '
             'Multiples are gathered into list.')
    parser.add_argument(
        '-s', '--solvent',
        dest="solventResnames",
        nargs='+',
        help='include resname of solvent molecules (case-sensitive) '
             'Default: None. Multiples are gathered into list.')

    # ---- Solute entropy switches -----------------------------------------
    parser.add_argument(
        "--wm",
        dest="doMolEntropy",
        action="store_true",
        help="Do entropy calculation at whole molecule level "
             "(The whole molecule is treated as one single bead.)")
    parser.add_argument(
        "--res",
        dest="doResEntropy",
        action="store_true",
        help="Do entropy calculation at residue level "
             "(A residue as a whole represents a bead.)")
    parser.add_argument(
        "--uatom",
        dest="doUAEntropy",
        action="store_true",
        help="Do entropy calculation at united atom level "
             "(A heavy atom and its covalently bonded H-atoms form a united "
             "atom and represent a bead.)")
    parser.add_argument(
        "--topog",
        dest="topogEntropyMethod",
        type=int,
        default=0,
        help="Compute the topographical entropy using\n 1 : pLogP method \n 2: Corr. pLogP method \n 5: Corr. pLogP after adaptive enumeration of states. Default: 0 = no topographical analysis")

    # ---- Solution (POSEIDON) entropy switches ----------------------------
    parser.add_argument(
        "--solwm",
        dest="doSolMolEntropy",
        action="store_true",
        help="Do solution entropy calculation at whole molecule level "
             "(The whole molecule is treated as one single bead.)")
    parser.add_argument(
        "--solres",
        dest="doSolResEntropy",
        action="store_true",
        help="Do solution entropy calculation at residue level "
             "(A residue as a whole represents a bead.)")
    parser.add_argument(
        "--soluatom",
        dest="doSolUAEntropy",
        action="store_true",
        help="Do solution entropy calculation at united atom level "
             "(A heavy atom and its covalently bonded H-atoms form a united "
             "atom and represent a bead.)")
    parser.add_argument(
        "--solContact",
        dest="doSolContact",
        action="store_true",
        help="Do solute contact calculation")

    args = parser.parse_args()
except argparse.ArgumentError:
    # Raised by argparse when an option definition above is invalid (for
    # example two options sharing a flag); report it and re-raise so the
    # traceback is preserved.
    print('Command line arguments are ill-defined, please check the arguments')
    raise

						


# POSEIDON analysis level names used below: ['moleculeLevel', 'residLevel_resname', 'atomLevel', 'soluteContacts']

############## REPLACE INPUTS ##############
# Echo every parsed option so the run is self-documenting in the log.
# Only None is blanked out: legitimate falsy values such as a start frame
# of 0 or a disabled (False) switch are printed as they are.
print("printing all input")
for arg, value in vars(args).items():
    print(' {} {}'.format(arg, '' if value is None else value))

# Wall-clock reference for the "total time" report printed later.
startTime = datetime.now()

# ---- Input / output paths ----
filePath = args.filePath
outfile = args.outFile
moutFile = args.moutFile
nmdFile = args.nmdFile
rescsvOutFile = args.rescsvOutFile
csvOutFile = args.csvOutFile

# ---- Calculation settings ----
# The force/torque scale factors are fixed at 1.0 because the matching
# command-line options are currently disabled.
tScale = 1.0
fScale = 1.0
temper = args.temp
selection_string = args.selectionString
start = args.start
end = args.end
step = args.step
thread = args.thread
axis_list = args.rotationalAxis

# ---- Solute entropy switches ----
doMolEntropy = args.doMolEntropy
doResEntropy = args.doResEntropy
doUAEntropy = args.doUAEntropy
topogEntropyMethod = args.topogEntropyMethod

# ---- Solution (POSEIDON) switches and settings ----
doSolMolEntropy = args.doSolMolEntropy
doSolResEntropy = args.doSolResEntropy
doSolUAEntropy = args.doSolUAEntropy
doSolContact = args.doSolContact
cutShell = args.cutShell
puteAtomNum = args.puteAtomNum
excludedResnanes = args.excludedResnanes
waterResnames = args.waterResnames
solventResnames = args.solventResnames

# The first path given on the command line is the topology file; whatever
# follows is passed on as the list of trajectory files.
tprfile, trrfile = filePath[0], filePath[1:]
u = mda.Universe(tprfile, trrfile)

# Empty results table; each enabled analysis below appends its rows to it.
results_df = pd.DataFrame(
    columns=['Method and Level', 'Type', 'result'],
)

# Frame-reduced universe (start/end/step applied by the helper), written out
# under a name that starts with its frame count.
reduced_frame = MDAHelper.new_U_select_frame(u, start, end, step)
reduced_frame_name = "{}_frame_dump".format(len(reduced_frame.trajectory))
reduced_frame_filename = MDAHelper.write_universe(
    reduced_frame,
    reduced_frame_name,
)

# Atom-reduced universe built from the frame-reduced one using the user's
# selection string, written out the same way.
reduced_atom = MDAHelper.new_U_select_atom(reduced_frame, selection_string)
reduced_atom_name = "{}_frame_dump_strip_solvent".format(
    len(reduced_atom.trajectory)
)
reduced_atom_filename = MDAHelper.write_universe(
    reduced_atom,
    reduced_atom_name,
)

# Container handed to all of the solute entropy calculations below.
dataContainer = DC.DataContainer(reduced_atom)

if doMolEntropy:
    # Whole-molecule level: returns a force-force (FF) and a torque-torque
    # (TT) entropy value, both of which are reported and tabulated.
    wm_entropyFF, wm_entropyTT = EF.compute_entropy_whole_molecule_level(
        arg_hostDataContainer=dataContainer,
        arg_outFile=outfile,
        arg_selector="all",
        arg_moutFile=moutFile,
        arg_nmdFile=nmdFile,
        arg_fScale=fScale,
        arg_tScale=tScale,
        arg_temper=temper,
        arg_verbose=5,
    )

    for label, value in (
        ("wm_entropyFF", wm_entropyFF),
        ("wm_entropyTT", wm_entropyTT),
    ):
        print(f"{label} = {value}")

    # One row per component, appended to the shared results table.
    newRow = pd.DataFrame(
        {
            'Method and Level': ['Whole molecule'] * 2,
            'Type': ['FF Entropy (J/mol/K)', 'TT Entropy (J/mol/K)'],
            'result': [wm_entropyFF, wm_entropyTT],
        }
    )
    results_df = pd.concat([results_df, newRow], ignore_index=True)

if doResEntropy:
    # Residue level: returns a force-force (FF) and a torque-torque (TT)
    # entropy value. The rotational axis atoms come from the command line.
    res_entropyFF, res_entropyTT = EF.compute_entropy_residue_level(
        arg_hostDataContainer=dataContainer,
        arg_outFile=outfile,
        arg_selector='all',
        arg_moutFile=moutFile,
        arg_nmdFile=nmdFile,
        arg_fScale=fScale,
        arg_tScale=tScale,
        arg_temper=temper,
        arg_verbose=5,
        arg_axis_list=axis_list,
    )

    # Label the values like the other entropy levels do, so the numbers on
    # stdout can be told apart.
    print(f"res_entropyFF = {res_entropyFF}")
    print(f"res_entropyTT = {res_entropyTT}")

    # One row per component, appended to the shared results table.
    newRow = pd.DataFrame({
        'Method and Level': ['Residue', 'Residue'],
        'Type': ['FF Entropy (J/mol/K)', 'TT Entropy (J/mol/K)'],
        'result': [res_entropyFF, res_entropyTT],
    })
    results_df = pd.concat([results_df, newRow], ignore_index=True)

if doUAEntropy:
    # Keyword arguments shared by the single-process and multiprocess
    # united-atom routines.
    ua_kwargs = dict(
        arg_hostDataContainer=dataContainer,
        arg_outFile=outfile,
        arg_selector='all',
        arg_moutFile=moutFile,
        arg_nmdFile=nmdFile,
        arg_fScale=fScale,
        arg_tScale=tScale,
        arg_temper=temper,
        arg_verbose=1,
        arg_axis_list=axis_list,
        arg_csv_out=rescsvOutFile,
    )

    # More than one requested thread selects the multiprocess variant, which
    # additionally receives the thread count.
    if thread > 1:
        UA_entropyFF, UA_entropyTT, res_df = (
            EF.compute_entropy_UA_level_multiprocess(
                arg_thread=thread,
                **ua_kwargs,
            )
        )
    else:
        UA_entropyFF, UA_entropyTT, res_df = EF.compute_entropy_UA_level(
            **ua_kwargs
        )

    for label, value in (
        ("UA_entropyFF", UA_entropyFF),
        ("UA_entropyTT", UA_entropyTT),
    ):
        print(f"{label} = {value}")

    # One row per component, appended to the shared results table.
    newRow = pd.DataFrame(
        {
            'Method and Level': ['United atom'] * 2,
            'Type': ['FF Entropy (J/mol/K)', 'TT Entropy (J/mol/K)'],
            'result': [UA_entropyFF, UA_entropyTT],
        }
    )
    results_df = pd.concat([results_df, newRow], ignore_index=True)

# Topographical entropy. The method number comes from --topog:
#   0 = no topographical analysis, 1 = pLogP, 2 = corrected pLogP,
#   5 = corrected pLogP after adaptive enumeration of states (AEM).
# Methods 3 (compute_topographical_entropy_method3, "Phi Coeff", demanding
# when there are many dihedrals) and 4 (compute_topographical_entropy_method4,
# work in progress) are currently disabled.
if topogEntropyMethod == 1:
    # pLogP: side-chain (SC) and backbone (BB) are computed separately.
    result_entropy0_SC = EF.compute_topographical_entropy0_SC(
        arg_hostDataContainer=dataContainer,
        arg_selector="all",
        arg_outFile=outfile,
        arg_verbose=5,
    )
    print(f"result_entropy0_SC = {result_entropy0_SC}")

    result_entropy0_BB = EF.compute_topographical_entropy0_BB(
        arg_hostDataContainer=dataContainer,
        arg_selector="all",
        arg_outFile=outfile,
        arg_verbose=5,
    )
    print(f"result_entropy0_BB = {result_entropy0_BB}")

    newRow = pd.DataFrame({
        'Method and Level': ['Topographical', 'Topographical'],
        'Type': ['Total SC Topog. Entropy pLogp', 'Total BB Topog. Entropy pLogp'],
        'result': [result_entropy0_SC, result_entropy0_BB],
    })
    results_df = pd.concat([results_df, newRow], ignore_index=True)

elif topogEntropyMethod == 2:
    # Corrected pLogP: again one call each for side chain and backbone.
    result_entropy1_SC = EF.compute_topographical_entropy1_SC(
        arg_hostDataContainer=dataContainer,
        arg_selector="all",
        arg_outFile=outfile,
        arg_verbose=5,
    )
    print(f"result_entropy1_SC= {result_entropy1_SC}")

    result_entropy1_BB = EF.compute_topographical_entropy1_BB(
        arg_hostDataContainer=dataContainer,
        arg_selector="all",
        arg_outFile=outfile,
        arg_verbose=5,
    )
    print(f"result_entropy1_BB= {result_entropy1_BB}")

    newRow = pd.DataFrame({
        'Method and Level': ['Topographical', 'Topographical'],
        'Type': ['Total SC Topog. Entropy (corr. pLogP)', 'Total BB Topog. Entropy (corr. pLogP)'],
        'result': [result_entropy1_SC, result_entropy1_BB],
    })
    results_df = pd.concat([results_df, newRow], ignore_index=True)

elif topogEntropyMethod == 5:
    # Adaptive enumeration method: a single total value.
    result_entropyAEM = EF.compute_topographical_entropy_AEM(
        arg_hostDataContainer=dataContainer,
        arg_selector="all",
        arg_outFile=outfile,
        arg_verbose=5,
    )
    newRow = pd.DataFrame({
        'Method and Level': ['Topographical AEM'],
        'Type': ['Total Topog. Entropy (AEM)'],
        'result': [result_entropyAEM],
    })
    results_df = pd.concat([results_df, newRow], ignore_index=True)

    print(f"result_entropyAEM = {result_entropyAEM}")

elif topogEntropyMethod != 0:
    # Any other non-zero value used to be ignored without a word; say so
    # explicitly, but keep going so the remaining analyses still run.
    print(
        f"Unsupported topographical entropy method {topogEntropyMethod}; "
        "supported values are 1, 2 and 5. Skipping topographical analysis."
    )



# Report the time elapsed since startTime, then write the collected results
# table to the CSV file with NaN entries spelled out as the text 'nan'.
elapsed = datetime.now() - startTime
print(f"total time {elapsed}")
results_df = results_df.replace(to_replace=np.nan, value='nan')
results_df.to_csv(csvOutFile, index=False)

# Translate the solution-entropy switches into the level names handed to
# POSEIDON, keeping the fixed order molecule -> residue -> atom -> contacts.
level_list = [
    level
    for enabled, level in (
        (doSolMolEntropy, 'moleculeLevel'),
        (doSolResEntropy, 'residLevel_resname'),
        (doSolUAEntropy, 'atomLevel'),
        (doSolContact, 'soluteContacts'),
    )
    if enabled
]

# POSEIDON only runs when at least one level was requested.
if level_list:
    # The full universe `u` is passed together with start/end/step, i.e. the
    # frame range is handed to Poseidon rather than pre-applied here.
    poseidon_object = Poseidon(
        container=u,
        start=start,
        end=end,
        step=step,
        pureAtomNum=puteAtomNum,
        cutShell=cutShell,
        water=waterResnames,
        excludedResnames=excludedResnanes,
        verbose=True,
    )
    # forceUnits is not passed, so the callee's default applies.
    poseidon_object.run_analysis(
        temperature=temper,
        level_list=level_list,
        solvent=solventResnames,
        water=waterResnames,
        verbose=True,
        weighting=None,
    )