#!/usr/bin/env python2.7

import logging, sys, optparse, itertools
from collections import defaultdict, namedtuple
from os.path import join, basename, dirname, isfile

# === COMMAND LINE INTERFACE, OPTIONS AND HELP ===
parser = optparse.OptionParser("""usage: %prog [options] filename - merge the panther classification PTHR9.0_human and the HPRD classification into the format symbol-main class

HPRD:
uses the most specific class
just ignores the TREMBL proteins (a few hundred)

panther data is in /hive/data/outside/panther/3.3
hprd data is in /hive/data/outside/hprd/081814/
hprd data received from Arun Patil by email, arun@ibioinformatics.org

output:
symbol,description

to load:
create table geneClasses (gene varchar(255), text varchar(30000));
hgLoadSqlTab publications geneClasses geneClasses.tab
""")

# verbosity switch
parser.add_option(
    "-d", "--debug",
    dest="debug",
    action="store_true",
    help="show debug messages")
#parser.add_option("-k", "--keggDir", dest="keggDir", action="store", help="the KEGG ftp mirror directory on disk, default %default", default="/hive/data/outside/kegg/06032011/ftp.genome.jp/pub/kegg") 

# locations of the three input tables; every one has a built-in default path
parser.add_option(
    "-s", "--hgncFname",
    dest="hgncFname",
    action="store",
    help="the HGNC tab file on disk, default %default",
    default="/hive/data/outside/hgnc/111413/hgnc_complete_set.txt")
parser.add_option(
    "-p", "--pantherFname",
    dest="pantherFname",
    action="store",
    help="the panther tab file on disk, default %default",
    default="/hive/data/outside/panther/3.3/PTHR9.0_human")
parser.add_option(
    "", "--hprdFname",
    dest="hprdFname",
    action="store",
    help="the HPRD tab file on disk, default %default",
    default="/hive/data/outside/hprd/081814/HPRD_molecular_class_081914.txt")
#parser.add_option("-u", "--uniprotFname", dest="uniprotFname", action="store", help="the uniprot file from the pubs parser, default %default", default="/hive/data/inside/pubs/parsedDbs/uniprot.9606.tab")
#parser.add_option("-f", "--file", dest="file", action="store", help="run on file") 
#parser.add_option("", "--test", dest="test", action="store_true", help="do something") 
options, args = parser.parse_args()

# DEBUG level with -d, INFO otherwise
logging.basicConfig(level=logging.DEBUG if options.debug else logging.INFO)
# ==== FUNCTIONs =====
def lineFileNext(fh, headers=None):
    """ 
        parses tab-sep file with headers as field names 
        yields collection.namedtuples
        strips "#"-prefix from header line

        fh: an open, line-iterable file object
        headers: optional list of field names. If None, the names are taken
            from the first line, with spaces replaced by "_" and parentheses
            removed so they can be used as namedtuple fields.

        The first line of fh is always consumed, also when headers are
        supplied by the caller.

        Lines whose field count does not match the headers are reported via
        logging.error and skipped; they do not abort the iteration.
    """
    line1 = fh.readline()
    line1 = line1.strip("\n").strip("#")
    if headers is None:
        headers = [
            h.replace(" ", "_").replace("(", "").replace(")", "")
            for h in line1.split("\t")
        ]
    Record = namedtuple('tsvRec', headers)

    for line in fh:
        line = line.rstrip("\n")
        fields = line.split("\t")
        try:
            rec = Record(*fields)
        except TypeError as msg:
            # namedtuple raises TypeError on a wrong number of fields
            logging.error("Exception occured while parsing line, %s" % msg)
            # not every file-like object has a .name (e.g. StringIO)
            logging.error("Filename %s" % getattr(fh, "name", "<unknown>"))
            logging.error("Line was: %s" % repr(line))
            logging.error("Does number of fields match headers?")
            logging.error("Headers are: %s" % headers)
            continue
        yield rec

def resolveFamilies(root, idToMember):
    """ register every family molecule in idToMember

    For each Model/MoleculeList/Molecule element that has a FamilyMemberList
    child, stores idToMember[moleculeId] = ("family", name, symbols), where
    name is the "value" attribute of the first Name child and symbols is the
    |-joined last tuple element of all members already present in idToMember.
    Members that are not in idToMember yet are left out.
    """
    logging.info("pass2 - families")
    for molecule in root.findall("Model/MoleculeList/Molecule"):
        familyId = molecule.attrib["id"]
        # value is not used, but the attribute is still required to be present
        familyType = molecule.attrib["molecule_type"]
        memberList = molecule.find("FamilyMemberList")
        if memberList is None:
            # not a family
            continue

        # <FamilyMemberList>
        # <Member member_molecule_idref="200502">
        familyName = molecule.find("Name").attrib["value"]
        memberIds = [
            member.attrib["member_molecule_idref"]
            for member in memberList.findall("Member")
        ]
        memberSyms = [
            idToMember[memberId][-1]
            for memberId in memberIds
            if memberId in idToMember
        ]
        idToMember[familyId] = ("family", familyName, "|".join(memberSyms))


def trySymFromName(nameStr, accToSym):
    """ try a few things to get the official symbol of a synonym

    nameStr: a molecule name; complexes are written as "/"-separated parts
    accToSym: dict mapping an accession or (lower-cased) synonym to a symbol

    For a "/"-separated name every part is resolved on its own and the
    resolved symbols are returned as a |-separated list. Otherwise, or if no
    part could be resolved, these keys are tried in order and the first hit
    is returned:
      - the name as given
      - the lower-cased text before the first "-"
      - the lower-cased text before the first "/"
      - the lower-cased first whitespace-separated word
    Returns "" if nothing matches.
    """
    if "/" in nameStr:
        syms = []
        for part in nameStr.split("/"):
            partSym = trySymFromName(part.strip(), accToSym)
            if partSym != "":
                syms.append(partSym)
        if len(syms) != 0:
            return "|".join(syms)

    candidates = [
        nameStr,
        nameStr.split("-")[0].lower(),
        nameStr.split("/")[0].lower(),
    ]
    words = nameStr.split()
    if words:  # empty or all-whitespace names have no first word
        candidates.append(words[0].lower())

    for candidate in candidates:
        if candidate == "":
            # never resolve through an empty key
            continue
        sym = accToSym.get(candidate, "")
        if sym != "":
            return sym

    # give up
    return ""

def parseXml(filename, accToSym):
    """ parse NCI XML and yield one row per pair of interaction participants

    filename: path of the XML file, handed to et.parse()
    accToSym: dict used to translate UniProt accessions (and, through
        trySymFromName, names) to symbols

    This is a generator. Every yielded row is a list:
        "pid<eventId>",
        the three fields of the first member tuple,
        the three fields of the second member tuple,
        interaction type, relation ("participates", "activates" or "inhibits"),
        "pid", interaction id, pathway name, |-joined PMIDs, |-joined evidence names
    A member tuple is (type, name, symbols-or-id) as stored in idToMember.

    NOTE(review): relies on the names "et" and "evidToName", which are not
    defined anywhere in the visible part of this file - presumably an
    ElementTree module and an evidence-code lookup dict; confirm before
    calling this function.
    """
    logging.info(filename)
    rows = []  # never used below
    tree = et.parse(filename)
    root = tree.getroot()


    source = basename(filename).split(".")[0]  # never used below
    idToMember = {} # id -> ("complex" or "family" or "gene", name, |-sep list of genes)
    idToSym = {} # id -> sym  (never filled or read below)

    # pass 1: all molecules without a FamilyMemberList: proteins, RNAs,
    # component-less complexes and compounds
    logging.info("pass1")
    skipUps = []  # UniProt accessions that are not in accToSym
    noSym = set()  # NOTE(review): nothing is ever added, the log lines below always report 0
    for mEl in root.findall("Model/MoleculeList/Molecule"):
        # <Molecule molecule_type="protein" id="201405">
        # <Name name_type="UP" long_name_type="UniProt" value="P98170" />
        famMemEl = mEl.find("FamilyMemberList")
        # ignore molecules with families
        if famMemEl!=None:
            continue
        molId = mEl.attrib["id"]
        molType = mEl.attrib["molecule_type"]

        # name_type -> value; a later Name with the same type overwrites an earlier one
        names = {}
        for nameEl in mEl.findall("Name"):
            names[nameEl.attrib["name_type"]] = nameEl.attrib["value"]

        # <ComplexComponentList>
        # <ComplexComponent molecule_idref="215568">
        # </ComplexComponentList>
        complCompEls = mEl.findall("ComplexComponentList/ComplexComponent")
        if molType in ["protein", "rna"] or (molType=="complex" and len(complCompEls)==0):
            # weird enough, this can happen: a complex with a protein UP entry... 
            # and no components -> we treat these like a normal protein
            upAcc = None
            sym = ""
            if "UP" in names:
                upAcc = names["UP"]
                # drop everything from the first "-" on before the lookup
                upAcc = upAcc.split("-")[0]
                sym = accToSym.get(upAcc, None)
                if sym==None:
                    # unknown accession: remember it and drop the molecule
                    skipUps.append(upAcc)
                    continue
            # raises KeyError if the molecule has no "PF" name
            nameStr = names["PF"]

            # no UniProt accession: try to guess the symbol from the name
            if sym=="":
                sym = trySymFromName(nameStr, accToSym)
            # NOTE(review): compType is computed here but not used; the tuple
            # stored below always says "gene" - confirm this is intended
            compType = "gene"
            if "|" in sym:
                compType = "complex"

            idToMember[molId] = ("gene", nameStr, sym)
        elif molType=="compound":
            # <Molecule molecule_type="compound" id="201403">
            # <Name name_type="CA" long_name_type="Chemical Abstracts" value="521-18-6" />
            # <Name name_type="PF" long_name_type="preferred symbol" value="DHT" />
            # NOTE(review): if a compound lacks a CA or PF name, idStr / name
            # keep the value of an earlier iteration or are undefined
            for nameEl in mEl.findall("Name"):
                if nameEl.attrib["name_type"]=="CA":
                    idStr = "compound:CAS"+nameEl.attrib["value"]
                elif nameEl.attrib["name_type"]=="PF":
                    name = nameEl.attrib["value"]
            idToMember[molId] = ("compound", name, idStr)

    # the second number is the count of molecules stored so far
    logging.info("Could not resolve %d uniprot IDs, out of %d" % (len(skipUps), len(idToMember)))
    logging.info("%d molecules without an official symbol" % (len(noSym)))
    logging.debug("list: %s" %  ",".join(noSym))

    # pass 2: families whose members were stored in pass 1
    logging.info("pass2 families")
    resolveFamilies(root, idToMember)

    # pass 3: complexes without a UniProt name, built from their components
    logging.info("pass3 complexes")
    for mEl in root.findall("Model/MoleculeList/Molecule"):
        molId = mEl.attrib["id"]
        molType = mEl.attrib["molecule_type"]
        famMemEl = mEl.find("FamilyMemberList")  # never used in this loop
        names = {}
        for nameEl in mEl.findall("Name"):
            names[nameEl.attrib["name_type"]] = nameEl.attrib["value"]
        if molType=="complex" and not "UP" in names:
            # <Name name_type="PF" long_name_type="preferred symbol" value="FA complex/FANCD2/Ubiquitin" />
            symList = []
            name = names["PF"]
            for compEl in mEl.findall("ComplexComponentList/ComplexComponent"):
                compId = compEl.attrib["molecule_idref"]
                if compId not in idToMember:
                    # just skip invalid ones
                    continue
                compType, compName, compSym = idToMember[compId]
                # a family counts as one component: join its members with "_"
                # so that "|" only separates the components of the complex
                if compType=="family":
                    compSym = compSym.replace("|", "_")
                if compSym=="": # some members don't have an offical symbol, skip them
                    continue
                # keep every component only once
                if compSym in symList:
                    continue
                symList.append(compSym)
            #assert(symList!=[]) # we have some complexes without member proteins

            # some are not annotated, we resort to guessing
            if len(symList)==0:
                symStr = trySymFromName(name, accToSym)
            else:
                symStr = "|".join(symList)
            idToMember[molId] = ("complex", name, symStr)
            
    # pass 4: run the family pass again, now that complexes are in idToMember
    # and can be picked up as family members
    logging.info("pass4 families again")
    resolveFamilies(root, idToMember)

    #for id, tup in idToMember.iteritems():
        #print id, tup

    # create a dict with interactionId -> pathway name
    # (an interaction listed in several pathways keeps only the last one)
    intIdToName = {}
    for pwEl in root.findall("Model/PathwayList/Pathway"):
        pwName = pwEl.find("LongName").text
        #print pwName
        pwCompEls = pwEl.findall("PathwayComponentList/PathwayComponent")
        # NOTE(review): ElementTree's findall() returns an empty list rather
        # than None, so this check presumably never triggers
        if pwCompEls==None:
            continue
        for pwCompEl in pwCompEls:
            intIdToName[pwCompEl.attrib["interaction_idref"]]=pwName

    # now iterate over interactions and output their components as pairs
    eventId = 0  # running number used in the first column of every row
    skipCount = 0  # interaction components that are not in idToMember
    for ic in root.findall("Model/InteractionList/Interaction"):
        intId = ic.attrib["id"]
        iType = ic.attrib["interaction_type"]
        if iType in ["pathway", "subnet"]: # refs to complete pathways
            continue
        #print iType
        #assert(iType in ["transcription", "modification", "translocation"])
        src = ic.find("Source").text  # never used below

        # evidence codes, translated through the evidToName lookup
        evidList = []
        for evidEl in ic.findall("EvidenceList/Evidence"):
            evidList.append(evidToName[evidEl.text])
        evidStr = "|".join(evidList)

        pmids = []
        for refEl in ic.findall("ReferenceList/Reference"):
            pmids.append(refEl.attrib["pmid"])
        pmidStr = "|".join(pmids)

        edges = defaultdict(set)  # role -> set of member tuples
        allParts = set()  # all member tuples of this interaction
        for intCompEl in ic.findall("InteractionComponentList/InteractionComponent"):
            idref = intCompEl.attrib["molecule_idref"]
            role = intCompEl.attrib["role_type"]
            member = idToMember.get(idref, None)
            if member==None:
                skipCount += 1
                continue

            assert(role in ["input", "output", "inhibitor", "agent"])
            edges[role].add(member)
            allParts.add(member)

        # memCount is only referenced by the commented-out print below
        memCount = 0
        for compType, iList in edges.iteritems():
            memCount+= len(iList)

        # one row per unordered pair of participants; the relation depends
        # only on the role of the first member of the pair
        for pair in itertools.combinations(allParts,2 ):
            m1, m2 = pair
            subRel = "participates"
            if m1 in edges["agent"]:
                subRel = "activates"
            elif m1 in edges["inhibitor"]:
                subRel = "inhibits"

            # raises KeyError for an interaction that is not part of any pathway
            pwName = intIdToName[intId]

            row = ["pid%d"%eventId]
            row.extend(m1)
            row.extend(m2)
            row.append(iType)
            row.append(subRel)
            row.append("pid")
            row.append(intId)
            row.append(pwName)
            row.append(pmidStr)
            row.append(evidStr)
            eventId += 1

            yield row

        #print len(edges), memCount, intId, iType, evidStr, pmidStr, edges
    logging.warn("Could not resolve %d members of interactions (probably non-human species)" % skipCount)

def pipeSplitAddAll(string, dict, key, toLower=False):
    """ map every |-separated value of string to key in dict

    With toLower=True the values are lower-cased before they are used as
    dictionary keys. Existing entries are overwritten.
    """
    values = string.split("|")
    if toLower:
        values = [value.lower() for value in values]
    for value in values:
        dict[value] = key

def parseUniprot(uniprotTable, accToSym):
    """ parse uniprot and return dict with accession or synonym -> symbol

    uniprotTable: path of a tab-separated file with a header line that has
        at least the columns hgncSym, mainIsoAcc, accList, acc, geneName,
        geneSynonyms, isoNames, protFullNames, protShortNames,
        protAltFullNames and protAltShortNames
    accToSym: dict that is filled in place and also returned

    The symbol of a row is the first |-separated value of hgncSym.
    Accessions are added as they are, names and synonyms lower-cased.
    Later rows and later columns overwrite earlier entries with the same key.
    """
    logging.info("Parsing %s" % uniprotTable)
    accFields = ("mainIsoAcc", "accList", "acc")
    # various synonyms: names & symbols, stored lower-cased
    nameFields = (
        "geneName", "geneSynonyms", "isoNames",
        "protFullNames", "protShortNames",
        "protAltFullNames", "protAltShortNames",
    )
    # "with" closes the table once all rows are read or on error
    with open(uniprotTable) as uniprotFh:
        for row in lineFileNext(uniprotFh):
            sym = row.hgncSym.split("|")[0]
            accToSym[sym] = sym # some entries have the HGNC symbol in the xref
            for field in accFields:
                pipeSplitAddAll(getattr(row, field), accToSym, sym)
            for field in nameFields:
                pipeSplitAddAll(getattr(row, field), accToSym, sym, toLower=True)
    return accToSym
        
def parseHgnc(fname, addEntrez=False):
    """ return a uniprotAcc -> hgnc symbol dict from the HGNC tab-sep file

    fname: path of the HGNC table; its header line must provide the columns
        that lineFileNext turns into Approved_Symbol,
        UniProt_ID_supplied_by_UniProt and Entrez_Gene_ID
    addEntrez: if True, non-empty Entrez gene IDs are added as keys, too

    Rows are skipped if the symbol contains "withdrawn" or "-AS", or if the
    UniProt field is empty or "-". When a UniProt accession occurs on several
    rows, the first symbol wins; the later symbols are logged.
    """
    logging.info("Parsing HGNC table")
    upToSym = {}
    skipSyms = set()
    # "with" closes the table once all rows are read or on error
    with open(fname) as hgncFh:
        for row in lineFileNext(hgncFh):
            sym = row.Approved_Symbol
            if "withdrawn" in sym or "-AS" in sym:
                continue
            upAcc = row.UniProt_ID_supplied_by_UniProt
            if upAcc=="" or upAcc=="-":
                continue
            if upAcc in upToSym:
                # accession already assigned to another symbol
                skipSyms.add(sym)
                continue
            entrez = row.Entrez_Gene_ID
            if addEntrez and entrez!="":
                upToSym[entrez] = sym
            upToSym[upAcc] = sym
    # sorted, so that the message is the same on every run
    logging.info("Skipped these symbols due to duplicate uniprot IDs: %s" % ",".join(sorted(skipSyms)))
    return upToSym

# ----------- MAIN --------------
#if args==[]:
    #parser.print_help()
    #exit(1)

pantherFilename = options.pantherFname

#uniprotTable = "/hive/data/inside/pubs/parsedDbs/uniprot.9606.tab"
accToSym = parseHgnc(options.hgncFname)

# class name -> number of annotations with this class, over all genes
catCounts = defaultdict(int)

# symbol -> list of class names, one entry per annotation
symToCats = defaultdict(list)

# "with" closes the panther table after the loop or on error
with open(pantherFilename) as pantherFh:
    for line in pantherFh:
        fs = line.rstrip("\n").split('\t')
        # first column is |-separated, the accession is in the last part,
        # written as UniProt...=<accession>
        idStr = fs[0].split("|")[-1]
        if not idStr.startswith('UniProt'):
            continue
        acc = idStr.split("=")[-1]
        if acc not in accToSym:
            logging.warning("Not found: %s" % acc)
            continue
        sym = accToSym[acc]

        # the classes are read from the ninth column; skip truncated lines
        # instead of dying with an IndexError
        if len(fs) <= 8:
            logging.warning("No class column for %s, skipping line" % acc)
            continue

        # classes are ;-separated, each written as name#id
        for classField in fs[8].split(";"):
            c = classField.split("#")[0]
            if c == "":
                # no class annotated, nothing to count
                continue
            catCounts[c] += 1
            symToCats[sym].append(c)

#for cat, count in catCounts.iteritems():
    #print count, cat

# symbol -> the rarest of its classes; ties are broken alphabetically
pantherClasses = {}
for sym, cats in symToCats.items():
    # smallest (count, name) tuple = class with the fewest annotations
    bestCat = min((catCounts[c], c) for c in cats)[-1]
    bestCat = bestCat.replace(" molecule", "")
    bestCat = bestCat.replace("G protein", "G-protein")
    bestCat = bestCat.replace(";", ", ")
    pantherClasses[sym] = bestCat
    
# process hprd and write to dict sym -> text
# field names for the HPRD table. They are passed to lineFileNext explicitly;
# lineFileNext still reads and discards the first line of the file.
headers = "sym,entrez,desc,refseq,classStr,empty".split(",")
# symbol -> HPRD class string; a later row for the same symbol wins
hprdClasses = {}
# "with" closes the HPRD table after the loop or on error
with open(options.hprdFname) as hprdFh:
    for row in lineFileNext(hprdFh, headers=headers):
        # rows without a usable class
        if row.classStr.startswith("Unclass"):
            continue
        if row.classStr=="-":
            continue
        classStr = row.classStr
        classStr = classStr.replace("G protein", "G-protein")
        classStr = classStr.replace(";", ", ")
        hprdClasses[row.sym] = classStr

# join both dicts
allSyms = set(pantherClasses).union(hprdClasses)
# sorted, so that the output is in the same order on every run
for sym in sorted(allSyms):
    hprd = hprdClasses.get(sym, "")
    panther = pantherClasses.get(sym, "")

    # take longer desc if overlap by one word (case-insensitive),
    # otherwise keep both, HPRD first
    hprdWords = set(hprd.lower().split())
    if hprdWords.intersection(panther.lower().split()):
        if len(hprd)>len(panther):
            parts = [hprd]
        else:
            parts = [panther]
    else:
        parts = [hprd, panther]

    # drop the source that has no class for this symbol
    parts = [p for p in parts if p!=""]
    desc = ", ".join(parts)
    if desc=="":
        continue
    # upper-case the first letter only, leave the rest untouched
    desc = desc[0].upper()+desc[1:]
    row = [sym, desc]
    # one tab-separated line per symbol; this form of print works with
    # python 2 and python 3
    print("\t".join(row))
