#!/usr/bin/env python2.7

import logging, sys, optparse, itertools
from collections import defaultdict, namedtuple
from os.path import join, basename, dirname, isfile
#import xml.etree.ElementTree as et
import xml.etree.cElementTree as et

# Maps the evidence codes found in the text of <Evidence> elements of the PID XML
# to the readable names that parseXml writes to the "evidence" output column.
# A code that is missing from this table makes parseXml fail with a KeyError.
# NOTE(review): "expresssion_pattern" (three s) looks like a typo, but the string is
# part of the emitted output, so it is left as is -- confirm with downstream users
# before correcting it.
evidToName = {
    "IDA" : "assay",
    "IC" : "curator",
    "IGI" : "genetic_interaction",
    "IEP" : "expresssion_pattern",
    "IPI" : "physical_interaction",
    "RCA" : "computational_analysis",
    "IAE" : "array_experiment",
    "IFC" : "functional_complementation",
    "ISS" : "sequence_or_structure",
    "IMP" : "mutant_phenotype",
    "IOS" : "other_species",
    "NIL" : "unknown",
    "RGE" : "reporter_gene"
}

# Links to individual interactions used to look like this:
# http://pid.nci.nih.gov/search/InteractionPage?atomid=204109
# That site has been down since the end of 2016.
# NDEx copied the data; their links look like this:
# http://www.ndexbio.org/#/search?searchType=All&searchString=labels%253Amtor_4pathway
#

# Column names of the tab-separated output, in output order. The main script prints
# them once as a "#"-prefixed header line; every row yielded by parseXml has one
# value per column.
headers = "eventId,causeType,causeName,causeGenes,themeType,themeName,themeGenes,relType,relSubtype,sourceDb,sourceId,sourceDesc,pmids,evidence".split(",")

# Running counter used to build the "pid<N>" event ID of each output row. It is
# incremented by parseXml and must keep counting across input files, hence global.
eventId = 0

# === COMMAND LINE INTERFACE, OPTIONS AND HELP ===
# NOTE: the command line is parsed at module level, i.e. as soon as this file is
# executed or imported. "options" and "args" are read by the main section at the
# bottom of the file.
parser = optparse.OptionParser("""usage: %prog [options] filename - convert NCI PID xml files to tab-sep format.

If filename contains 'BioCarta', adapt the pathway names for biocarta links
""")

parser.add_option("-d", "--debug", dest="debug", action="store_true", help="show debug messages")
#parser.add_option("-k", "--keggDir", dest="keggDir", action="store", help="the KEGG ftp mirror directory on disk, default %default", default="/hive/data/outside/kegg/06032011/ftp.genome.jp/pub/kegg") 
# The two lookup tables below are needed to map UniProt/Entrez accessions and
# synonyms to gene symbols; the defaults are fixed paths on the original host.
parser.add_option("-s", "--hgncFname", dest="hgncFname", action="store", help="the HGNC tab file on disk, default %default", default="/hive/data/outside/hgnc/111413/hgnc_complete_set.txt")
parser.add_option("-u", "--uniprotFname", dest="uniprotFname", action="store", help="the uniprot file from the pubs parser, default %default", default="/hive/data/inside/pubs/parsedDbs/uniprot.9606.tab")
#parser.add_option("-f", "--file", dest="file", action="store", help="run on file") 
#parser.add_option("", "--test", dest="test", action="store_true", help="do something") 
# args holds the positional arguments, i.e. the XML files to convert
(options, args) = parser.parse_args()

# -d switches the root logger from INFO to DEBUG verbosity
if options.debug:
    logging.basicConfig(level=logging.DEBUG)
else:
    logging.basicConfig(level=logging.INFO)
# ==== FUNCTIONs =====
def lineFileNext(fh):
    """
        parses tab-sep file with headers as field names
        yields collection.namedtuples, all field values are strings

        The first line is the header line: a "#"-prefix is stripped, spaces in
        column names are replaced with "_" and round brackets are removed, so the
        names can be used as namedtuple attributes.

        Lines whose number of fields does not match the header are logged as
        errors and skipped, they do not stop the iteration.

        fh is an open file or file-like object; it does not need a "name"
        attribute, the name is only used in error messages.
    """
    line1 = fh.readline()
    line1 = line1.strip("\n").strip("#")
    headers = [
        h.replace(" ", "_").replace("(", "").replace(")", "")
        for h in line1.split("\t")
    ]
    Record = namedtuple('tsvRec', headers)

    for line in fh:
        line = line.rstrip("\n")
        fields = line.split("\t")
        try:
            rec = Record(*fields)
        except TypeError as msg:
            # wrong number of fields: report as much context as possible, then go on
            logging.error("Exception occured while parsing line, %s", msg)
            logging.error("Filename %s", getattr(fh, "name", "<unknown>"))
            logging.error("Line was: %s", repr(line))
            logging.error("Does number of fields match headers?")
            logging.error("Headers are: %s", headers)
            continue
        yield rec

def resolveFamilies(root, idToMember):
    """ add id to member tuple to idToMember dict for all families

    root is the root element of the parsed XML. For every <Molecule> that has a
    <FamilyMemberList>, an entry molecule id -> ("family", name, symbols) is
    stored in idToMember, where name is the value of the molecule's first <Name>
    element and symbols is the |-separated list of the members' symbols (the last
    field of their own idToMember entries).

    Members that are not in idToMember yet are skipped, which is why the caller
    runs this function again after complexes have been resolved. Members with an
    empty symbol are skipped as well, so the list never contains empty fields.

    idToMember is modified in place, nothing is returned.
    """
    logging.info("pass2 - families")
    for mEl in root.findall("Model/MoleculeList/Molecule"):
        famMemEl = mEl.find("FamilyMemberList")
        # compare with None: an element without children would be falsy
        if famMemEl is None:
            continue

        # <FamilyMemberList>
        # <Member member_molecule_idref="200502">
        molId = mEl.attrib["id"]
        name = mEl.find("Name").attrib["value"]
        symList = []
        for memEl in famMemEl.findall("Member"):
            memId = memEl.attrib["member_molecule_idref"]
            if memId not in idToMember:
                # not defined (yet), may get resolved in a later pass
                continue
            sym = idToMember[memId][-1]
            if sym == "":
                # member without a symbol, e.g. a compound without CAS number
                continue
            symList.append(sym)
        idToMember[molId] = ("family", name, "|".join(symList))


def trySymFromName(nameStr, accToSym):
    """ try a few things to get the official symbol of a synonym

    nameStr is a free-text molecule name, accToSym is a dict that maps accessions
    and synonyms to symbols (callers in this file store synonyms lowercased).

    Names with "/" are treated as complexes: every part is resolved on its own
    and the symbols found are returned as a |-separated list.

    Otherwise these keys are tried in order and the first hit is returned:
    the name as is, the lowercased name, the lowercased part before the first
    "-", the lowercased part before the first "/" and the lowercased first word.

    Returns "" if nothing could be resolved or if the name is empty.
    """
    if "/" in nameStr:
        syms = []
        for part in nameStr.split("/"):
            partSym = trySymFromName(part.strip(), accToSym)
            if partSym != "":
                syms.append(partSym)
        if len(syms) != 0:
            return "|".join(syms)

    nameStr = nameStr.strip()
    if nameStr == "":
        # nothing to look up; also protects the [0] on the word split below
        return ""

    candidates = [
        nameStr,
        nameStr.lower(),
        nameStr.split("-")[0].lower(),
        nameStr.split("/")[0].lower(),
        nameStr.split()[0].lower(),
    ]
    for cand in candidates:
        sym = accToSym.get(cand, "")
        if sym != "":
            return sym
    # give up
    return ""

def _molNames(mEl):
    " return a dict name_type -> value built from the <Name> children of a <Molecule> element "
    return dict((nameEl.attrib["name_type"], nameEl.attrib["value"]) for nameEl in mEl.findall("Name"))


def parseXml(filename, accToSym):
    """ parse NCI XML and yield one row per pair of interaction participants

    This is a generator. Every row is a list of strings with one value per
    column of the global "headers" variable. The global eventId counter is
    incremented for every row, so IDs continue across input files.

    filename: path of the XML file. If it contains 'BioCarta', the source
        database is "biocarta" instead of "pid" and the pathway short names are
        rewritten to the h_xxxPathway form.
    accToSym: dict accession or synonym -> gene symbol

    Molecules are resolved in several passes into idToMember, a dict
    molecule id -> (type, descriptive name, |-sep list of gene symbols):
    pass1 proteins, rnas and compounds, pass2 families, pass3 complexes and
    pass4 families again, as families can contain complexes.

    Raises KeyError for an evidence code that is not in evidToName and for an
    interaction with participant pairs that is not part of any pathway.
    """
    global eventId

    logging.info(filename)
    tree = et.parse(filename)
    root = tree.getroot()
    isBioCarta = "BioCarta" in filename

    idToMember = {} # id -> tuple ("complex" or "family" or "gene", descriptive name, |-sep list of gene symbols)

    logging.info("pass1")
    skipSyms = [] # names of the molecules that could not be mapped to a symbol
    for mEl in root.findall("Model/MoleculeList/Molecule"):
        # <Molecule molecule_type="protein" id="201405">
        # <Name name_type="UP" long_name_type="UniProt" value="P98170" />
        # ignore molecules with families, they are handled by resolveFamilies
        if mEl.find("FamilyMemberList") is not None:
            continue
        molId = mEl.attrib["id"]
        molType = mEl.attrib["molecule_type"]
        names = _molNames(mEl)

        # <ComplexComponentList>
        # <ComplexComponent molecule_idref="215568">
        # </ComplexComponentList>
        complCompEls = mEl.findall("ComplexComponentList/ComplexComponent")
        if molType in ["protein", "rna"] or (molType=="complex" and len(complCompEls)==0):
            # weird enough, this can happen: a complex with a protein UP entry...
            # and no components -> we treat these like a normal protein
            sym = ""
            if "UP" in names:
                # strip the isoform suffix, e.g. P12345-2 -> P12345
                upAcc = names["UP"].split("-")[0]
                sym = accToSym.get(upAcc, "")
            elif "OF" in names:
                sym = names["OF"]
            elif "LL" in names:
                sym = accToSym.get(names["LL"], "")

            if sym=="":
                # molecules without a symbol are dropped, only remember them for the log
                skipSyms.append("/".join(names.values()))
                continue

            if "PF" in names:
                nameStr = names["PF"]
            elif "AS" in names:
                nameStr = names["AS"]
            else:
                nameStr = ""

            idToMember[molId] = ("gene", nameStr, sym)
        elif molType=="compound":
            # <Molecule molecule_type="compound" id="201403">
            # <Name name_type="CA" long_name_type="Chemical Abstracts" value="521-18-6" />
            # <Name name_type="PF" long_name_type="preferred symbol" value="DHT" />
            name = ""
            sym = ""
            if "CA" in names:
                sym = "compound:CAS"+names["CA"]
            if "PF" in names:
                name = names["PF"]
            # unlike for proteins, an AS name takes precedence over PF here
            if "AS" in names:
                name = names["AS"]
            idToMember[molId] = ("compound", name, sym)

    logging.info("Could not resolve %d uniprot IDs, out of %d (debug-mode shows them)" % \
        (len(skipSyms), len(idToMember)))
    logging.debug("Unresolvable symbols: %s" %  ",".join(skipSyms))

    logging.info("pass2 families")
    resolveFamilies(root, idToMember)

    logging.info("pass3 complexes")
    for mEl in root.findall("Model/MoleculeList/Molecule"):
        molId = mEl.attrib["id"]
        if mEl.attrib["molecule_type"]!="complex":
            continue
        names = _molNames(mEl)

        complCompEls = mEl.findall("ComplexComponentList/ComplexComponent")
        # it's only a complex if we have members
        if len(complCompEls)==0:
            logging.debug("Complex without members: %s" % "|".join(names.values()))
            continue

        # <Name name_type="PF" long_name_type="preferred symbol" value="FA complex/FANCD2/Ubiquitin" />
        if "PF" in names:
            name = names["PF"]
        else:
            name = names["AS"]

        symList = []
        for compEl in complCompEls:
            compId = compEl.attrib["molecule_idref"]
            if compId not in idToMember:
                # just skip invalid ones
                continue
            compType, compName, compSym = idToMember[compId]
            if compType=="family":
                # keep the family together as a single member of the complex
                compSym = compSym.replace("|", "_")
            if compSym=="": # some members don't have an offical symbol, skip them
                continue
            if compSym in symList:
                continue
            symList.append(compSym)

        # some are not annotated, we resort to guessing
        if len(symList)==0:
            symStr = trySymFromName(name, accToSym)
        else:
            symStr = "|".join(symList)
        idToMember[molId] = ("complex", name, symStr)

    logging.info("pass4 families again")
    resolveFamilies(root, idToMember)

    # create dicts interactionId -> pathway name; if an interaction is part of
    # several pathways, the last one wins
    intIdToName = {}
    # this is only a temporary fix for the NDex database: they do not have interaction pages anymore, unlike the original PID
    intIdToShortName = {}
    for pwEl in root.findall("Model/PathwayList/Pathway"):
        pwName = pwEl.find("LongName").text
        pwShortName = pwEl.find("ShortName").text
        if isBioCarta:
            # bad hack, but uppercase is lost in the xml
            pwShortName = "h_"+pwShortName.replace("pathway", "Pathway")
        for pwCompEl in pwEl.findall("PathwayComponentList/PathwayComponent"):
            pwIntId = pwCompEl.attrib["interaction_idref"]
            intIdToName[pwIntId] = pwName
            intIdToShortName[pwIntId] = pwShortName

    # now iterate over interactions and output their components as pairs
    skipCount = 0
    dbName = "pid"
    if isBioCarta:
        dbName = "biocarta"

    for ic in root.findall("Model/InteractionList/Interaction"):
        intId = ic.attrib["id"]
        iType = ic.attrib["interaction_type"]
        if iType in ["pathway", "subnet"]: # refs to complete pathways
            continue

        # prepare evidence and pmid strings
        evidList = [evidToName[evidEl.text] for evidEl in ic.findall("EvidenceList/Evidence")]
        evidStr = "|".join(evidList)

        pmids = [refEl.attrib["pmid"] for refEl in ic.findall("ReferenceList/Reference")]
        pmidStr = "|".join(pmids)

        edges = defaultdict(set) # role -> set of member tuples
        allParts = set()
        for intCompEl in ic.findall("InteractionComponentList/InteractionComponent"):
            idref = intCompEl.attrib["molecule_idref"]
            role = intCompEl.attrib["role_type"]
            member = idToMember.get(idref, None)
            if member is None:
                skipCount += 1
                continue

            assert(role in ["input", "output", "inhibitor", "agent"])
            edges[role].add(member)
            allParts.add(member)

        agents = edges["agent"]
        inhibitors = edges["inhibitor"]
        outputs = edges["output"]

        # now create the pairs
        # NOTE(review): combinations() visits every unordered pair only once, in set
        # iteration order. The relation subtype is derived from m1 only, so a pair
        # that arrives as (output, agent/inhibitor) is dropped by the "continue"
        # below and its reverse is never seen -- confirm that this is intended.
        for m1, m2 in itertools.combinations(allParts, 2):
            # m1 and m2 are tuples (type, name, symbolList)
            subRel = "collaborate"
            if m1 in agents:
                subRel = "activates"
            elif m1 in inhibitors:
                subRel = "inhibits"
            elif m1 in outputs and (m2 in agents or m2 in inhibitors):
                # we already have act/inhibitor -> output covered, no need to add the reverse
                continue

            # looked up only when a row is written, so interactions without pairs
            # do not need to be part of a pathway
            pwName = intIdToName[intId]

            row = ["pid%d"%eventId]
            row.extend(m1)
            row.extend(m2)
            row.append(iType)
            row.append(subRel)
            row.append(dbName)
            row.append(intIdToShortName[intId])
            row.append(pwName)
            row.append(pmidStr)
            row.append(evidStr)
            eventId += 1

            yield row

    logging.warning("Resolved %d members. Could not resolve %d members in interactions" % (len(idToMember), skipCount))

def pipeSplitAddAll(string, dict, key, toLower=False):
    """ split on pipe and add all values to dict with key

    Every non-empty part of the |-separated string becomes a key of dict (the
    parameter, not the builtin) that points to the value "key". Existing entries
    are overwritten. With toLower=True the parts are lowercased first.

    Empty parts are skipped: an empty input field would otherwise create an
    entry for the empty string.
    """
    for val in string.split("|"):
        if val == "":
            continue
        if toLower:
            val = val.lower()
        dict[val] = key

def parseUniprot(uniprotTable, accToSym):
    """ parse uniprot and return dict with accession or synonym -> symbol

    uniprotTable is the name of a tab-sep file with a header line. The symbol of
    a row is the first |-separated value of its hgncSym column.

    Accessions and IDs are added to accToSym as they are, gene and protein names
    and synonyms are added lowercased. accToSym is modified in place and also
    returned; existing entries are overwritten.

    Rows without a symbol are skipped: they cannot resolve anything and would
    replace existing entries with an empty symbol.
    """
    logging.info("Parsing %s" % uniprotTable)
    # columns with |-separated accessions/IDs, added as they are
    idFields = ("entrezGene", "mainIsoAcc", "accList", "acc")
    # columns with |-separated names & symbols, added in lowercase
    nameFields = ("geneName", "geneSynonyms", "isoNames", "protFullNames",
        "protShortNames", "protAltFullNames", "protAltShortNames")

    with open(uniprotTable) as ifh:
        for row in lineFileNext(ifh):
            sym = row.hgncSym.split("|")[0]
            if sym == "":
                continue
            accToSym[sym] = sym # some entries have the HGNC symbol in the xref
            for fieldName in idFields:
                pipeSplitAddAll(getattr(row, fieldName), accToSym, sym)
            for fieldName in nameFields:
                pipeSplitAddAll(getattr(row, fieldName), accToSym, sym, toLower=True)
    return accToSym
        
def parseHgnc(fname, addEntrez=False):
    """ return a uniprotAcc -> hgnc symbol dict from the HGNC tab-sep file

    Rows are skipped if the symbol contains "withdrawn" or "-AS" or if the row
    has no UniProt accession ("" or "-"). If a UniProt accession appears in
    several rows, the first one wins; the symbols of the other rows are only
    reported in an info log message.

    With addEntrez=True, the Entrez Gene ID of every row that is kept is added
    as a key, too.
    """
    upToSym = {}
    skipSyms = set()
    with open(fname) as ifh:
        for row in lineFileNext(ifh):
            sym = row.Approved_Symbol
            if "withdrawn" in sym or "-AS" in sym:
                continue
            upAcc = row.UniProt_ID_supplied_by_UniProt
            if upAcc=="" or upAcc=="-":
                continue
            if upAcc in upToSym:
                # accession already assigned to another symbol
                skipSyms.add(sym)
                continue
            entrez = row.Entrez_Gene_ID
            if addEntrez and entrez!="":
                upToSym[entrez] = sym
            upToSym[upAcc] = sym
    # sorted, so the message is the same on every run
    logging.info("Skipped these symbols due to duplicate uniprot IDs: %s" % ",".join(sorted(skipSyms)))
    return upToSym

# ----------- MAIN --------------
# without input files there is nothing to do: show the usage message and stop
if not args:
    parser.print_help()
    # sys.exit instead of exit(): the latter is a helper of the site module that
    # is meant for interactive sessions
    sys.exit(1)

filenames = args

# build the accession/synonym -> symbol table: HGNC first, then UniProt, so
# UniProt entries overwrite HGNC entries with the same key
accToSym = parseHgnc(options.hgncFname)
accToSym = parseUniprot(options.uniprotFname, accToSym)

# the header line is written once, the rows of all input files follow
sys.stdout.write("#" + "\t".join(headers) + "\n")

for filename in filenames:
    logging.debug(filename)
    for row in parseXml(filename, accToSym):
        line = u"\t".join(row)
        # values from the XML can be unicode strings, the output is UTF-8
        sys.stdout.write(line.encode("utf8") + "\n")
