#!/usr/bin/python
# -*- coding: utf-8 -*-
#--------------------------------------------------------------------------------
#process_pdb.py v0.1, Copyright Bjoern Olausson
#--------------------------------------------------------------------------------
#This program is free software; you can redistribute it and/or modify
#it under the terms of the GNU General Public License as published by
#the Free Software Foundation; either version 2 of the License, or
#(at your option) any later version.
#
#This program is distributed in the hope that it will be useful,
#but WITHOUT ANY WARRANTY; without even the implied warranty of
#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
#GNU General Public License for more details.
#
#To view the license visit
#http://www.gnu.org/licenses/old-licenses/gpl-2.0.html
#or write to
#Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
#--------------------------------------------------------------------------------
#--------------------------------------------------------------------------------
#
#This tool gets extended as I need some more functionality.
#
# Currently it is intended to
# -fix missing HETATM records for specific residue names (see CONFIG section)
# -fix engineered residues (see CONFIG section)
# -remove all residues not defined in the CONFIG section
#
#I know I should have used dictionaries ;-)
#################################CONFIG############################
#Valid amino acids. TRP is the standard PDB code for tryptophan; it was
#missing, so cleanPDB() removed every tryptophan. TRY is kept as before.
AA = ( 'ALA', 'ARG', 'ASN', 'ASP', 'ASX', 'CYS', 'GLU', 'GLN', 'GLX', 'GLY', 'HIS', 'ILE', 'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TRY', 'TYR', 'VAL' )
#Valid nucleic acids
NA = ( 'ADE', 'CYT', 'GUA', 'THY', 'URA' )

#Hetero residues that are kept as valid (every entry in HET needs a
#counterpart at the same position in HETID)
HET = ( 'HOH', 'MG', 'TIP3' )
#Hetero flag stored in residue.id[0] for the HET entry at the same position
HETID = ( 'W', 'M', 'W' )

#Engineered residues (every entry in EN needs a counterpart at the same
#position in ENR)
EN = ( 'CSP', 'GMA', 'HIP', 'MLY', 'MSE', 'PDS', 'PHL', 'PTR', 'SEP', 'TPO' )
#Replacement for engineered residues according to charmm-gui
ENR = ( 'CYS', 'GLU2', 'HSD', 'LYS', 'MET', 'ASP', 'ASP', 'TYR1', 'SER1', 'THR1' )

#Protonizable residues
PRES = ( 'ASP', 'GLU', 'LYS', 'HIS' )

#Set protonation state for HIS | Protonated His --> HSP | neutral HIS, proton on ND1 --> HSD | neutral His, proton on NE2 --> HSE
HISTIDIN = "HSD"

#Misc. residues to save in a separate file when splitting the PDB.
#This must be a tuple: "( 'CYT' )" without the trailing comma is a plain
#string, which turns "resname in MISC" into a substring test ('C', 'T', ...).
MISC = ( 'CYT', )
###################################################################

import sys, os, shutil, string, glob, subprocess, itertools, time, tempfile

from Bio.PDB import PDBIO
from Bio.PDB.PDBIO import Select
pdbwrite = PDBIO()
from Bio.PDB.PDBParser import PDBParser
pdbparse = PDBParser(PERMISSIVE=1)

from optparse import OptionParser
parser = OptionParser()

from sets import Set

parser.add_option(
    "-i",
    "--input",
    action="store",
    dest="I",
    help="PDB-File to mod",
    metavar="INPUTFILE"
)

# Restrict -r to the documented keywords so that a typo is rejected by
# optparse instead of silently skipping the renumbering step.
parser.add_option(
    "-r",
    "--renumber",
    action="store",
    type="choice",
    choices=("atm", "res", "all"),
    dest="R",
    help="""Use the keywords "atm" "res" "all" to renumber only atoms, residues or both """,
    metavar="RENUM"
)

parser.add_option(
    "-c",
    "--clean",
    action="store_true",
    dest="C",
    help="Clean the PDB",
    metavar="CLEAN"
)

parser.add_option(
    "-w",
    "--water",
    action="store_true",
    dest="W",
    help="Fix PDB water (HOH) to be charmm complient (TIP3)",
    metavar="WATER",
)

parser.add_option(
    "-s",
    "--split",
    action="store_true",
    dest="S",
    help="Split PDB file with multible chains into multible files each containing a single chain. Additionaly all HETATMs are stored in a seperate file. If split fails, try to clean (-c) the pdb first",
    metavar="SPLIT",
)

parser.add_option(
    "-p",
    "--protlist",
    action="store_true",
    dest="P",
    help="List all protonizable residues in the Protein (ASP -> [ASPP], GLU -> [GLUP], LYS -> [LSN] , HIS -> [HSD, HSE, HSP])",
    metavar="PROTABLE",
)

parser.add_option(
    "-m",
    "--chm",
    action="store_true",
    dest="M",
    help="Convert atomtypes to CHARMM atomtypes",
    metavar="CHM",
)

parser.add_option(
    "-t",
    "--test",
    action="store_true",
    dest="T",
    help="Just for testing stuff",
    metavar="TEST",
)
# instruct optparse to parse the program's command line
(options, args) = parser.parse_args()

# Checking for required options
if not options.I:
    parser.error("You must provide a file to process (-i file.pdb)")

# Assign options to variables
INPUT = options.I
optCLEAN = options.C
optWATER = options.W
optRENUM = options.R
optSPLIT = options.S
optTEST = options.T
optPROTABLE = options.P
optCHM = options.M

# Working directory at start-up (not changed by this script).
path = os.getcwd()

# Input path without its extension; every output file name is built from it,
# so outputs are written next to the input file.
oldPDB = os.path.splitext(INPUT)[0]

# Parse the file the user actually named (e.g. "foo.ent"). Fall back to
# "<base>.pdb" so that "-i foo" without an extension still works.
if os.path.isfile(INPUT):
    pdbFILE = INPUT
elif os.path.isfile(oldPDB + ".pdb"):
    pdbFILE = oldPDB + ".pdb"
else:
    parser.error("Input file not found: %s" % INPUT)

#parse the PDB-file
structure = pdbparse.get_structure(oldPDB, pdbFILE)
header = pdbparse.get_header()
trailer = pdbparse.get_trailer()

def cleanPDB(model) :
    """Normalise the residues of *model* in place and write "<input>-clean.pdb".

    For every chain of *model*:
      * residues named in AA or NA are kept unchanged;
      * residues named in EN are renamed to the ENR entry at the same position;
      * residues named in HET get their hetero flag (residue.id[0]) set to
        the HETID entry at the same position;
      * every other residue is detached from its chain.
    Chains that are empty afterwards (or were empty to begin with) are
    detached from *model*.

    Note: the file is written from the module-level ``structure``, not from
    *model* itself.
    """
    engineered = dict(zip(EN, ENR))
    het_flags = dict(zip(HET, HETID))

    # Iterate over snapshots: detach_child() removes entries from the
    # child_list being iterated, which would silently skip the next element.
    for chain in list(model):
        for residue in list(chain):
            resn = residue.get_resname()
            if resn in AA or resn in NA:
                print("ATOM: %s" % (resn,))
            elif resn in engineered:
                replacement = engineered[resn]
                print("ATOM: %s <---> %s (%s)" % (resn, replacement, resn))
                residue.resname = replacement
            elif resn in het_flags:
                residue.id = (het_flags[resn], residue.id[1], residue.id[2])
                print("HETATM: %s ---> Fixed RECORD to HETATM" % (resn,))
            else:
                print("INVALD: %s REMOVED" % (residue,))
                chain.detach_child(residue.id)
        if len(chain) == 0:
            model.detach_child(chain.id)
    writePDB(structure, "clean")

def renumRES(model) :
    """Renumber the residues of every chain in *model* consecutively from 1.

    Only the sequence number (id[1]) changes; the hetero flag (id[0]) and
    insertion code (id[2]) of each residue are kept. For every non-empty
    chain the offset of its first residue is printed. The module-level
    ``structure`` is then written to "<input>-renum.pdb".
    """
    for chain in model:
        residues = chain.child_list
        if not residues:
            # Nothing to renumber (and residues[0] would raise IndexError).
            continue
        offset = residues[0].id[1] - 1
        for new_number, residue in enumerate(residues, 1):
            residue.id = (residue.id[0], new_number, residue.id[2])
        # Reported per chain; previously only the last chain's offset was shown.
        print("Residue OFFSET was -%s" % (offset,))
    writePDB(structure, "renum")

def renumATM(model) :
    """Placeholder for atom renumbering; *model* is not used yet."""
    notice = "Not implemented yet"
    print(notice)

def splitPDB(model):
    """Write every chain of *model* into its own set of PDB files.

    For each chain id X, three files are written next to the input:
      <input>_chain-x-misc.pdb    residues of chain X whose name is in MISC
      <input>_chain-x-hetatm.pdb  residues of chain X whose hetero flag is in HETID
      <input>_chain-x-prot.pdb    chain X after those residues were removed
    The MISC/HETID residues are detached from the chain before the "prot"
    file is written, so *model* is modified in place. Files are written from
    the module-level ``structure``.
    """
    class _ChainSelect(Select):
        """Selector bound to a single chain id."""
        def __init__(self, chain_id):
            self._chain_id = chain_id

        def _in_chain(self, res):
            # full id = (structure id, model id, chain id, residue id)
            return res.get_full_id()[2] == self._chain_id

    class SelectChain(_ChainSelect):
        """Accept only the bound chain."""
        def accept_chain(self, chain):
            return chain.get_id() == self._chain_id

    class SelectHETATM(_ChainSelect):
        """Accept residues of the bound chain whose hetero flag is in HETID."""
        def accept_residue(self, res):
            return res.id[0] in HETID and self._in_chain(res)

    class SelectMISCRE(_ChainSelect):
        """Accept residues of the bound chain whose name is in MISC."""
        def accept_residue(self, res):
            return res.resname in MISC and self._in_chain(res)

    pdbwrite.set_structure(structure)
    for chain in model:
        chain_id = chain.id
        prefix = oldPDB + "_chain-" + chain_id.lower()

        pdbwrite.save(prefix + "-misc.pdb", SelectMISCRE(chain_id))
        pdbwrite.save(prefix + "-hetatm.pdb", SelectHETATM(chain_id))

        # Snapshot the residues: detaching from the live child_list while
        # iterating it skips the element after each removal, which left every
        # second consecutive water/HETATM in the "-prot" file.
        for residue in list(chain):
            if residue.id[0] in HETID or residue.resname in MISC:
                chain.detach_child(residue.id)
        pdbwrite.save(prefix + "-prot.pdb", SelectChain(chain_id))

def listProtRes(model) :
    """Print a protonation-state hint for each protonizable residue type in *model*.

    A residue type is reported once if at least one residue in *model* has
    that name and the name is listed in PRES. Types are reported in PRES
    order, so the output is reproducible. (Iterating a set gave an arbitrary
    order.)
    """
    hints = {
        "ASP": " | Protonated aspartic acid, proton on od2 --> ASPP",
        "HIS": " | Protonated His --> HSP | neutral HIS, proton on ND1 --> HSD | neutral His, proton on NE2 --> HSE",
        "GLU": " | Protonated glutamic acid, proton on oe2--> GLUP",
        "LYS": " | Neutral --> LSN",
    }
    present = set()
    for chain in model.child_list :
        for residue in chain.child_list :
            if residue.resname in PRES :
                present.add(residue.resname)
    for name in PRES :
        if name in present and name in hints :
            # discard() guarantees one line per type even if PRES repeats a name
            present.discard(name)
            print("%s %s" % (name, hints[name]))

def convCHM(model) :
    """Rename residues and atoms of *model* to CHARMM naming and write
    "<input>-charmm.pdb".

    Changes made in place, per residue:
      * HOH becomes TIP3 and its atom O is written as OH2;
      * HIS becomes HISTIDIN (protonation state chosen in the CONFIG section);
      * ILE, SER and MET atoms are renamed as listed in ``atom_renames``.
    The file is written from the module-level ``structure``. Afterwards every
    "TIP3A" in the written file is replaced by "TIP3 ".

    TODO: terminal patching (1HT/2HT/3HT -> HT1/HT2/HT3 on the first residue,
    OXT -> OT1 and O -> OT2 on the last) was drafted but never enabled; it is
    not performed here.
    """
    # New atom fullnames must always be 4 characters wide:
    # single char [.X..], two chars [.XY.], three chars [XYZ.], four chars [WXYZ]
    atom_renames = {
        'ILE': {'CD1': ' CD ', '1HD1': 'HD1 ', '2HD1': 'HD2 ', '3HD1': 'HD3 '},
        'SER': {'1HG': 'HG1 '},
        # e.g. MSE residues renamed to MET by cleanPDB() still carry an SE atom
        'MET': {'SE': ' SD '},
    }

    for chain in model :
        for residue in chain.child_list :
            resid = residue.id[1]
            if residue.resname == 'HOH':
                print("HOH(%s) --> %s" % (resid, "TIP3"))
                residue.resname = 'TIP3'
                for atom in residue.child_list :
                    if atom.name == "O" :
                        print("TIP3(%s): O --> OH2" % (resid,))
                        atom.fullname = 'OH2 '
            elif residue.resname == 'HIS':
                print("HIS(%s) --> %s" % (resid, HISTIDIN))
                residue.resname = HISTIDIN
            elif residue.resname in atom_renames:
                renames = atom_renames[residue.resname]
                for atom in residue.child_list :
                    new_fullname = renames.get(atom.name)
                    if new_fullname is not None:
                        print("%s(%s): %s --> %s" % (residue.resname, resid, atom.name, new_fullname.strip()))
                        atom.fullname = new_fullname

    writePDB(structure, "charmm")

    # Rewrite the output in place. Both handles are closed (and thus flushed)
    # by "with". Writing an unflushed temporary file and hard-linking it over
    # the output could leave a truncated file, and os.link fails when the
    # temp directory is on another filesystem.
    filename = oldPDB + '-charmm.pdb'
    with open(filename) as in_f:
        content = in_f.read()
    with open(filename, 'w') as out_f:
        out_f.write(content.replace('TIP3A', 'TIP3 '))
	
def writePDB(struc, suffix) :
    """Save *struc* as "<oldPDB>-<suffix>.pdb" using the shared PDBIO writer.

    oldPDB is the input path without its extension, so the file is written
    next to the input file.
    """
    outname = "".join((oldPDB, "-", suffix, ".pdb"))
    pdbwrite.set_structure(struc)
    pdbwrite.save(outname)

# Validate the renumbering keyword before any output is written, so a typo
# does not silently skip renumbering after other files were already produced.
if optRENUM is not None and optRENUM not in ("atm", "res", "all") :
    parser.error('Unknown renumbering keyword "%s" (use "atm", "res" or "all")' % (optRENUM,))

if optCLEAN :
    cleanPDB(structure[0])
# "all" is advertised in the -r help text and runs both renumbering steps.
if optRENUM in ("res", "all") :
    renumRES(structure[0])
if optRENUM in ("atm", "all") :
    renumATM(structure[0])
if optSPLIT :
    splitPDB(structure[0])
if optCHM :
    print("converting to CHARMM format")
    convCHM(structure[0])

if optPROTABLE :
    listProtRes(structure[0])

if optTEST :
    # Debug aid: print the hetero flag of every residue in the first model.
    for chain in structure[0] :
        for residue in chain.child_list :
            print(residue.id[0])


#a.get_name()       # atom name (spaces stripped, e.g. "CA")
#a.get_id()         # id (equals atom name)
#a.get_coord()      # atomic coordinates
#a.get_bfactor()    # B factor
#a.get_occupancy()  # occupancy
#a.get_altloc()     # alternative location specifier
#a.get_sigatm()     # std. dev. of atomic parameters
#a.get_siguij()     # std. dev. of anisotropic B factor
#a.get_anisou()     # anisotropic B factor
#a.get_fullname()   # atom name (with spaces, e.g. ".CA.")
