#! /usr/bin/env python
import click

import glob

from simtk import unit as u
from simtk.openmm import LangevinIntegrator, Platform, XmlSerializer
from simtk.openmm.app import Simulation, StateDataReporter

from bgmol.tpl.hdf5 import HDF5Reporter
from bgmol.systems import MiniPeptide


def next_sequence():
    """Return the index of the next simulation segment to run.

    Scans the current working directory for checkpoint files named
    ``state<N>.xml`` and returns one more than the largest ``N`` found, or 0
    when there are none.

    Files that match the glob ``state*.xml`` but whose middle part is not a
    non-negative decimal integer (e.g. ``state_old.xml`` or ``state.xml``) are
    ignored rather than aborting the run with ``ValueError``.
    """
    prefix, suffix = "state", ".xml"
    # Strip exactly the prefix and suffix; str.replace would remove every
    # occurrence of "state" anywhere in the name.
    stems = (
        path[len(prefix):-len(suffix)]
        for path in glob.glob(f"{prefix}*{suffix}")
    )
    # isdecimal (not isdigit) guarantees int() can parse the stem.
    state_ids = [int(stem) for stem in stems if stem.isdecimal()]
    return max(state_ids, default=-1) + 1

@click.command()
@click.option("--platform", "-p", default="CUDA", help="OpenMM platform name")
@click.option("--temperature", "-t", type=float, default=300.0, help="temperature in K")
@click.option("--equilibration", "-e", type=float, default=0.1, help="equilibration time in ns")
@click.option("--ns", "-n", type=float, default=20, help="production time in ns")
def main(platform, temperature, equilibration, ns):
    """Run one segment of a MiniPeptide("A") Langevin simulation.

    The segment index N comes from next_sequence(). Segment 0 starts from the
    peptide's initial positions, then minimizes and equilibrates. Later
    segments resume from the previous segment's state file. Each segment
    writes data<N>.txt, traj<N>.h5 and finally state<N>.xml.
    """
    sequence = next_sequence()

    # settings
    timestep = 1.0 * u.femtosecond
    temperature_k = temperature * u.kelvin
    # Use round() for both step counts: the unit-converted quotient is a float
    # and int() truncation could drop a step (e.g. 99999.999... -> 99999).
    n_steps = round(ns * u.nanosecond / timestep)
    n_equilibration_steps = round(equilibration * u.nanosecond / timestep)

    # system
    peptide = MiniPeptide("A", solvated=False, constraints=None)

    # Decide the platform-specific properties from the platform *name*. Keep the
    # Platform object under a separate name: comparing the object itself with
    # "CUDA" is always False, so the CUDA properties would never be applied.
    platform_name = platform
    openmm_platform = Platform.getPlatformByName(platform_name)
    platform_properties = (
        {"DeviceIndex": '0', 'Precision': 'mixed'} if platform_name == "CUDA" else {}
    )
    simulation = Simulation(
        peptide.topology,
        peptide.system,
        LangevinIntegrator(temperature_k, 1 / u.picosecond, timestep),
        openmm_platform,
        platform_properties,
    )
    context = simulation.context

    # initial conditions: restart from the previous checkpoint, or start fresh
    print(f"Starting from Sequence {sequence}.")
    if sequence > 0:
        simulation.loadState(f"state{sequence-1}.xml")
    else:
        context.setPositions(peptide.positions)
        context.setVelocitiesToTemperature(temperature_k)
        print("Minimization")
        simulation.minimizeEnergy(maxIterations=1000)
        print("Equilibration")
        simulation.step(n_equilibration_steps)

    # reporters (every 1000 steps); attached after equilibration so that only
    # the production run is recorded.
    # NOTE(review): the system is built with solvated=False; volume/density
    # columns presumably do not describe a physical box here — confirm.
    simulation.reporters.append(StateDataReporter(
        f"data{sequence}.txt", 1000,
        step=True, time=True, potentialEnergy=True,
        kineticEnergy=True, temperature=True,
        volume=True, density=True)
    )
    simulation.reporters.append(HDF5Reporter(f"traj{sequence}.h5", 1000, forces=True))

    # run
    print(f"Running {n_steps} steps")
    simulation.step(n_steps)

    # checkpoint consumed by the next segment's next_sequence()/loadState
    simulation.saveState(f"state{sequence}.xml")


if __name__ == "__main__":
    # main is a click command: calling it parses the command-line options
    # and runs one simulation segment.
    main()

