#
#  Copyright (C) 2014 Novartis Institutes for BioMedical Research
#
#   @@ All Rights Reserved @@
#  This file is part of the RDKit.
#  The contents are covered by the terms of the BSD license
#  which is included in the file license.txt, found at the root
#  of the RDKit source tree.
#
from __future__ import print_function
from rdkit import Chem
import numpy
from numpy import linalg

def _generateMoments(mol,confId=-1,weights=None):
  """
  Return the principal moments of inertia of one conformer of *mol*.

  mol     -- molecule; must provide GetNumAtoms() and GetConformer(confId)
  confId  -- id of the conformer whose coordinates are used
             (-1 selects the default conformer)
  weights -- optional sequence with one weight per atom (e.g. atomic
             masses); every atom gets weight 1 when omitted

  Returns a numpy array with the three eigenvalues of the inertia tensor
  (taken about the weighted centre of the coordinates), as real numbers
  in ascending order.

  Raises ValueError if the molecule has no atoms, if the number of weights
  does not match the number of atoms, or if the weights sum to zero.
  """
  nAtoms = mol.GetNumAtoms()
  if nAtoms == 0:
    raise ValueError('cannot compute moments of a molecule without atoms')
  if weights is None:
    weights = numpy.ones(nAtoms)
  else:
    weights = numpy.asarray(weights, dtype=numpy.double)
    if weights.shape != (nAtoms,):
      raise ValueError('expected one weight per atom (%d), got shape %s'
                       % (nAtoms, weights.shape))
  totalWeight = weights.sum()
  if totalWeight == 0:
    raise ValueError('the atom weights sum to zero')

  conf = mol.GetConformer(confId)
  # explicit float dtype so the in-place centring below cannot fail on
  # integer coordinates
  coords = numpy.array([conf.GetAtomPosition(i) for i in range(nAtoms)],
                       dtype=numpy.double)
  # move the origin to the weighted centre
  coords -= weights.dot(coords) / totalWeight

  # inertia tensor: I = sum_i w_i * (|r_i|^2 * E - r_i r_i^T)
  sqDist = (coords * coords).sum(axis=1)
  tensor = numpy.eye(3) * weights.dot(sqDist) - (coords.T * weights).dot(coords)

  # the tensor is real and symmetric: eigvalsh always returns real
  # eigenvalues, already sorted ascending (eigvals may hand back a complex
  # array, whose sort order is not the intended one)
  return linalg.eigvalsh(tensor)

def _generate3D(mol):
  """
  Build a 3D structure for *mol* using rdkit_Novartis' Generate3D.build3d.

  The molecule is written as SMILES (the True argument requests isomeric
  SMILES), passed to build3d, and the returned 'ctab' text is parsed back
  with Chem.MolFromMolBlock.

  Returns the parsed RDKit molecule.

  Raises ValueError if build3d does not report status 'OK' or if the
  returned ctab cannot be parsed.
  """
  # imported inside the function; presumably this Novartis-internal module
  # is not available in every installation
  from rdkit_Novartis.Interfaces import Generate3D
  smi = Chem.MolToSmiles(mol,True)
  # NOTE(review): argument meaning of build3d is assumed from usage here
  # (input keyed 'smiles', output read from 'ctab') — confirm with its docs
  d = Generate3D.build3d({0:dict(smiles=smi)},'smiles','ctab',False)
  d = d[0]
  if d.get('status','FAIL')!='OK':
    raise ValueError('could not generate 3D structure')
  res = Chem.MolFromMolBlock(d['ctab'])
  # MolFromMolBlock signals a parse failure by returning None; fail here
  # instead of handing None to the caller
  if res is None:
    raise ValueError('could not parse generated 3D structure')
  return res
  

def _CalcPMIDescriptors(mol,confId=-1,weights=None,force=False):
  """
  returns pmi1,pmi2,pmi3,npr1,npr2

  Hydrogens are added (with coordinates) before the moments are computed.
  If mol has no conformers, a 3D structure is generated first.

  mol     -- the molecule
  confId  -- conformer to use (-1 selects the default conformer)
  weights -- optional per-atom weights for the hydrogen-added molecule,
             i.e. one weight per atom after Chem.AddHs; atomic masses are
             used when omitted
  force   -- recompute even if a cached result is available

  pmi1..pmi3 are the moments returned by _generateMoments (ascending);
  npr1 = pmi1/pmi3 and npr2 = pmi2/pmi3.

  Only results computed with the default (mass) weights are cached on the
  molecule. They are stored together with the confId they were computed
  for and are reused only for that same confId.
  """
  useDefaultWeights = weights is None
  if (useDefaultWeights and not force and hasattr(mol,'_PMIVals')
      and getattr(mol,'_PMIConfId',None)==confId):
    return mol._PMIVals

  if mol.GetNumConformers()==0:
    # no coordinates available: build a 3D structure first
    molH = Chem.AddHs(_generate3D(mol),addCoords=True)
  else:
    molH = Chem.AddHs(mol,addCoords=True)
  if useDefaultWeights:
    weights = [x.GetMass() for x in molH.GetAtoms()]
  moms = _generateMoments(molH,confId=confId,weights=weights)
  res = moms[0],moms[1],moms[2],moms[0]/moms[2],moms[1]/moms[2]

  # never cache custom-weight results: later default-weight calls would
  # otherwise return them
  if useDefaultWeights:
    mol._PMIVals = res
    mol._PMIConfId = confId
  return res
                       
def CalcPMI1(mol,confId=-1,force=False):
  """Return PMI1, the first (smallest) principal moment of inertia of *mol*.

  confId selects the conformer; force=True bypasses any cached result.
  """
  pmi1, _, _, _, _ = _CalcPMIDescriptors(mol, confId=confId, force=force)
  return pmi1
def CalcPMI2(mol,confId=-1,force=False):
  """Return PMI2, the second (middle) principal moment of inertia of *mol*.

  confId selects the conformer; force=True bypasses any cached result.
  """
  _, pmi2, _, _, _ = _CalcPMIDescriptors(mol, confId=confId, force=force)
  return pmi2
def CalcPMI3(mol,confId=-1,force=False):
  """Return PMI3, the third (largest) principal moment of inertia of *mol*.

  confId selects the conformer; force=True bypasses any cached result.
  """
  _, _, pmi3, _, _ = _CalcPMIDescriptors(mol, confId=confId, force=force)
  return pmi3
def CalcNPR1(mol,confId=-1,force=False):
  """Return NPR1, the normalized PMI ratio pmi1/pmi3 of *mol*.

  confId selects the conformer; force=True bypasses any cached result.
  """
  _, _, _, npr1, _ = _CalcPMIDescriptors(mol, confId=confId, force=force)
  return npr1
def CalcNPR2(mol,confId=-1,force=False):
  """Return NPR2, the normalized PMI ratio pmi2/pmi3 of *mol*.

  confId selects the conformer; force=True bypasses any cached result.
  """
  _, _, _, _, npr2 = _CalcPMIDescriptors(mol, confId=confId, force=force)
  return npr2


if __name__=='__main__':
  # Demo: two small amines with 3D coordinates (the N atom carries a +1
  # formal charge via "M  CHG"). The records also hold reference values in
  # OUTPUT_moe:* data fields (presumably computed with MOE) for comparison
  # with the printed results. The molfile blocks are column-sensitive, so
  # the text below must be kept byte-for-byte.
  sdfd="""XX_0000000001_001_001_001
     RDKit          3D

  5  4  0  0  0  0  0  0  0  0999 V2000
   -0.0372    1.5779    0.0423 C   0  0  0  0  0  0  0  0  0  0  0  0
   -0.0107    0.0264    0.0242 C   0  0  0  0  0  0  0  0  0  0  0  0
    0.6771   -0.4966   -1.1839 N   0  0  0  0  0  0  0  0  0  0  0  0
    0.7202   -1.9805   -1.2328 C   0  0  0  0  0  0  0  0  0  0  0  0
    1.4448   -2.4914   -2.5064 C   0  0  0  0  0  0  0  0  0  0  0  0
  1  2  1  0
  2  3  1  0
  3  4  1  0
  4  5  1  0
M  CHG  1   3   1
M  END

>  <SMILES_USERINPUT>
CCNCC

>  <STATUS>
OK

>  <COMMENT>
correct SMILES  3D structure successfully generated (#structures: 1; Protonation Abundance: 99.95%; Tautomerisation Abundance: 100%)

>  <OUTPUT_moe:npr2>
0.94636673

>  <OUTPUT_moe:npr1>
0.11207375

>  <OUTPUT_moe:pmi2>
251.33347

>  <OUTPUT_moe:pmi1>
29.764236

>  <OUTPUT_moe:pmi3>
265.57724

>  <OUTPUT_moe:pmi>
273.33746

>  <OUTPUT_moe:pmiY>
167.30069

>  <OUTPUT_moe:pmiZ>
75.545876

>  <OUTPUT_moe:pmiX>
30.490913


$$$$
XX_0000000002_001_001_001
     RDKit          3D

 12 12  0  0  0  0  0  0  0  0999 V2000
    0.1906    1.4156    0.2263 C   0  0  0  0  0  0  0  0  0  0  0  0
   -0.0330   -0.1087    0.0412 C   0  0  0  0  0  0  0  0  0  0  0  0
    0.7670   -0.7169   -0.9973 O   0  0  0  0  0  0  0  0  0  0  0  0
    2.1979   -0.6480   -0.8048 C   0  0  0  0  0  0  0  0  0  0  0  0
    2.8054    0.6364   -1.4304 C   0  0  0  0  0  0  0  0  0  0  0  0
    3.8673    0.1208   -2.4291 C   0  0  0  0  0  0  0  0  0  0  0  0
    3.3100   -1.1675   -2.8908 N   0  0  0  0  0  0  0  0  0  0  0  0
    2.8358   -1.7942   -1.6373 C   0  0  0  0  0  0  0  0  0  0  0  0
   -1.5059   -0.3843   -0.3459 C   0  0  0  0  0  0  0  0  0  0  0  0
   -1.7966    0.2344   -1.5223 F   0  0  0  0  0  0  0  0  0  0  0  0
   -1.6904   -1.7245   -0.4953 F   0  0  0  0  0  0  0  0  0  0  0  0
   -2.3503    0.0713    0.6196 F   0  0  0  0  0  0  0  0  0  0  0  0
  1  2  1  0
  2  3  1  0
  2  9  1  0
  3  4  1  0
  4  8  1  0
  4  5  1  0
  5  6  1  0
  6  7  1  0
  7  8  1  0
  9 10  1  0
  9 11  1  0
  9 12  1  0
M  CHG  1   7   1
M  END

>  <SMILES_USERINPUT>
C1CNCC1OC(C(F)(F)F)C

>  <STATUS>
OK

>  <COMMENT>
correct SMILES  3D structure successfully generated (#structures: 4; Protonation Abundance: 99.88%; Tautomerisation Abundance: 100%)

>  <OUTPUT_moe:npr2>
0.91831785

>  <OUTPUT_moe:npr1>
0.22067128

>  <OUTPUT_moe:pmi2>
1084.2975

>  <OUTPUT_moe:pmi1>
260.55609

>  <OUTPUT_moe:pmi3>
1180.7432

>  <OUTPUT_moe:pmi>
1262.7983

>  <OUTPUT_moe:pmiY>
167.13344

>  <OUTPUT_moe:pmiZ>
209.69215

>  <OUTPUT_moe:pmiX>
885.97278


$$$$
"""
  supplier = Chem.SDMolSupplier()
  supplier.SetData(sdfd)
  mols = [m for m in supplier]
  # print pmi1,pmi2,pmi3,npr1,npr2 for both records
  for idx in (0, 1):
    print(_CalcPMIDescriptors(mols[idx]))
