#!/usr/bin/env python2.7

import logging, sys, optparse
from collections import defaultdict, namedtuple
from os.path import join, basename, dirname, isfile
import xml.etree.ElementTree as et

# relation type names for the GPML arrow head styles that have a known meaning
intTypeNames = dict(Arrow="activation", TBar="inhibition")

# running counter (module-wide) used to give every emitted event a unique id
eventId = 0

# column names of the output table, in output order
headers = [
    "eventId",
    "causeType", "causeName", "causeGenes",
    "themeType", "themeName", "themeGenes",
    "relType", "relSubtype",
    "sourceDb", "sourceId", "sourceDesc",
    "pmids",
]

# === COMMAND LINE INTERFACE, OPTIONS AND HELP ===
# The usage text is spelled out line by line so that its exact whitespace is visible.
parser = optparse.OptionParser(usage=(
    "usage: %prog [options] filename - parse Wikipathways GPML files to tab-sep format of gene/family/complex + gene/family/complex + interaction type + subtype \n"
    "\n"
    "links to pathways look like this: http://www.wikipathways.org/index.php/Pathway:WP1991\n"
))

parser.add_option("-d", "--debug", action="store_true", dest="debug", help="show debug messages")
# options that are currently disabled:
#parser.add_option("-k", "--keggDir", dest="keggDir", action="store", help="the KEGG ftp mirror directory on disk, default %default", default="/hive/data/outside/kegg/06032011/ftp.genome.jp/pub/kegg") 
#parser.add_option("-s", "--hgncFname", dest="hgncFname", action="store", help="the HGNC tab file on disk, default %default", default="/hive/data/outside/hgnc/111413/hgnc_complete_set.txt") 
#parser.add_option("-f", "--file", dest="file", action="store", help="run on file") 
#parser.add_option("", "--test", dest="test", action="store_true", help="do something") 
options, args = parser.parse_args()

# -d switches the root logger from INFO to DEBUG
logging.basicConfig(level=logging.DEBUG if options.debug else logging.INFO)
# ==== FUNCTIONs =====
def lineFileNext(fh):
    """ 
        Parse a tab-separated file whose first line holds the field names.

        fh: an open file-like object; it is read line by line.

        Yields one collections.namedtuple ("tsvRec") per data line, with all
        values kept as strings. "#" characters around the header line are
        stripped before it is split into field names.

        Data lines whose number of fields does not match the header are
        logged as errors and skipped. An empty input yields nothing.
    """
    line1 = fh.readline()
    if not line1:
        # completely empty input: there is no header, so there are no records
        return
    line1 = line1.strip("\n").strip("#")
    fieldNames = line1.split("\t")
    Record = namedtuple('tsvRec', fieldNames)

    # not every file-like object has a name (e.g. StringIO); the name is only
    # needed for the error messages below
    sourceName = getattr(fh, "name", "<unknown>")

    for line in fh:
        line = line.rstrip("\n")
        fields = line.split("\t")
        try:
            rec = Record(*fields)
        except TypeError as msg:
            # the namedtuple constructor raises TypeError on a wrong field count
            logging.error("Exception occured while parsing line, %s" % msg)
            logging.error("Filename %s" % sourceName)
            logging.error("Line was: %s" % repr(line))
            logging.error("Does number of fields match headers?")
            logging.error("Headers are: %s" % fieldNames)
            continue
        yield rec

def symFromPoint(pe, idToGene, idToGroup, groupToSyms):
    """ Resolve the GraphRef of a <Point> element to a symbol string.

    A reference to a gene returns that gene's symbol. A reference to a group
    returns the symbols of the group's members joined with "/". Returns None
    (and logs a debug message) if the reference is neither.
    """
    ref = pe.attrib["GraphRef"]

    # direct gene references win over group references
    if ref in idToGene:
        return idToGene[ref]

    if ref not in idToGroup:
        logging.debug("Cannot find a symbol for graphRef %s" % ref)
        return None

    members = groupToSyms[idToGroup[ref]]
    return "/".join(members)

def parseGpml(filename, accToSym):
    """ Parse one WikiPathways GPML file into a list of interaction rows.

    filename: path of the GPML file. The pathway id is taken from the
        second-to-last "_"-separated part of the base name
        (NOTE(review): presumably names look like Hs_Title_WP707_79974.gpml,
        a name with no "_" raises IndexError - confirm against the input set).
    accToSym: dict mapping accessions and gene symbols to gene symbols.

    Returns a list of rows, each a list of 13 strings in the order of the
    module-level `headers`: event id, cause type, cause name (always empty),
    cause genes, theme type, theme name (always empty), theme genes,
    relation type, relation subtype (always empty), "wikipathways",
    pathway id, pathway title and the "|"-joined PMIDs.

    Every emitted row increments the module-level `eventId` counter.
    """
    global eventId

    logging.info(filename)
    root = et.parse(filename).getroot()
    _stripNamespaces(root)

    pathway = basename(filename).split("_")[-2]
    # e.g. "DNA Damage Response"
    title = root.attrib["Name"]

    idToGroup, groupToType = _parseGroups(root)
    idToPmid = _parsePmids(root)
    # all gene nodes are collected before any interaction is looked at, so the
    # result does not depend on the order of the elements in the file
    idToGene, groupToSyms = _parseGeneNodes(root, accToSym)

    rows = []
    for inter in root.findall("Interaction"):
        parsed = _parseInteraction(inter, idToGene, idToGroup, groupToSyms, groupToType)
        if parsed is None:
            continue
        causeType, causeGenes, themeType, themeGenes, relType = parsed

        relSubtype = ""
        eventId += 1
        eventIdStr = "wiki%d" % eventId

        # get all PMIDs
        pmids = []
        for refEl in inter.findall("BiopaxRef"):
            refId = refEl.text
            if refId not in idToPmid:
                logging.debug("BiopaxRef %s has no PMID" % refId)
                continue
            pmids.append(idToPmid[refId])

        rows.append([eventIdStr, causeType, "", causeGenes, themeType, "", themeGenes,
                     relType, relSubtype, "wikipathways", pathway, title, "|".join(pmids)])

    return rows


def _stripNamespaces(root):
    " remove the '{namespace}' prefix from the tag of root and all elements below it, in place "
    # Element.iter() replaces getiterator(), which is deprecated and no longer
    # exists in Python 3.9 and later
    for el in root.iter():
        if el.tag[0] == '{':
            el.tag = el.tag.split('}', 1)[1]


def _parseGroups(root):
    """ Collect the <Group> elements that have a GraphId.

    Returns (idToGroup, groupToType):
    idToGroup maps GraphId -> GroupId,
    groupToType maps GraphId -> "complex" or "family".
    """
    # <Group GroupId="a5395" GraphId="a24bf" Style="Complex" />
    idToGroup = {}
    groupToType = {}
    for g in root.findall("Group"):
        graphId = g.attrib.get("GraphId")
        if graphId is None:
            continue
        idToGroup[graphId] = g.attrib["GroupId"]

        style = g.attrib.get("Style", "Group")
        if style == "Complex":
            groupToType[graphId] = "complex"
        elif style == "Group":
            groupToType[graphId] = "family"  # XX careful - A. Pico email ?
        elif style == "Pathway":
            # pathway groups get no type
            logging.warning("A group is a pathway?")
        else:
            # raised explicitly: an assert statement would be stripped under -O
            raise AssertionError("unknown group style %r" % style)
    return idToGroup, groupToType


def _parsePmids(root):
    " return dict Biopax reference id -> PMID, for all PubMed <PublicationXref> elements "
    # <bp:PublicationXref rdf:id="fba">
    #       <bp:ID rdf:datatype="http://www.w3.org/2001/XMLSchema#string">20691899</bp:ID>
    #       <bp:DB rdf:datatype="http://www.w3.org/2001/XMLSchema#string">PubMed</bp:DB>
    # </bp:PublicationXref>
    rdfIdAttr = "{http://www.w3.org/1999/02/22-rdf-syntax-ns#}id"
    idToPmid = {}
    for xref in root.findall("Biopax/PublicationXref"):
        refId = xref.attrib.get(rdfIdAttr)
        dbEl = xref.find("DB")
        idEl = xref.find("ID")
        if refId is None or dbEl is None or idEl is None:
            logging.debug("Incomplete PublicationXref, skipping: %s" % xref.attrib)
            continue
        if dbEl.text != "PubMed":
            continue
        pmid = idEl.text
        if pmid is None:
            logging.debug("PMID is none?")
            continue
        idToPmid[refId] = pmid
    return idToPmid


def _parseGeneNodes(root, accToSym):
    """ Resolve all GeneProduct <DataNode> elements to gene symbols.

    Returns (idToGene, groupToSyms):
    idToGene maps GraphId -> symbol,
    groupToSyms is a defaultdict(list) mapping GroupRef -> member symbols.
    """
    #   <DataNode TextLabel="TLK2" GraphId="bd01f" Type="GeneProduct" GroupRef="a5395">
    #     <Xref Database="Ensembl" ID="ENSG00000146872" />
    #   </DataNode>
    idToGene = {}
    groupToSyms = defaultdict(list)
    for node in root.findall("DataNode"):
        if node.attrib.get("Type") != "GeneProduct":
            continue

        sym = node.attrib.get("TextLabel")
        # often the symbol is not the text label: fall back to the Xref accession
        if not sym or sym not in accToSym:
            xref = node.find("Xref")
            if xref is None:
                logging.debug("datanode without xref element, %s" % node)
                continue
            acc = xref.attrib.get("ID", "")
            db = xref.attrib.get("Database", "")
            if db == "" and acc == "":  # occurs very often, cut/paste error
                continue
            # an empty accession can never be resolved; do not look it up, as
            # accToSym may contain an "" key
            sym = accToSym.get(acc) if acc else None
            if not sym:
                logging.debug("Could not resolve accession %s, db %s" % (acc, db))
                continue

        graphId = node.attrib.get("GraphId")
        if graphId is None:  # ??
            continue
        idToGene[graphId] = sym

        groupRef = node.attrib.get("GroupRef")
        if groupRef is not None:
            groupToSyms[groupRef].append(sym)
    return idToGene, groupToSyms


def _parseInteraction(inter, idToGene, idToGroup, groupToSyms, groupToType):
    """ Extract cause, theme and relation type from one <Interaction> element.

    Returns (causeType, causeGenes, themeType, themeGenes, relType) or None
    if the interaction cannot be used.
    """
    # <Interaction GraphId="a2011">
    #    <BiopaxRef>e6f</BiopaxRef>
    #    <Graphics ConnectorType="Curved" ZOrder="12288" LineThickness="1.0">
    #      <Point X="594.0" Y="120.0" GraphRef="d81c9" RelX="-1.0" RelY="0.0" />
    #      <Point X="524.0" Y="191.0" />
    #      <Point X="377.0" Y="262.0" GraphRef="d9b7f" RelX="1.0" RelY="0.0" ArrowHead="Arrow" />
    #    </Graphics>
    #    <Xref Database="" ID="" />
    #  </Interaction>
    graphEl = inter.find("Graphics")
    if graphEl is None:
        logging.debug("Interaction without Graphics element: %s" % inter.attrib)
        return None

    causes = []  # (symbol, type) of the points without an arrow head
    themes = []  # (symbol, type) of the points with an arrow head
    relType = "unknown"
    for pe in graphEl.findall("Point"):
        if "GraphRef" not in pe.attrib:
            continue
        sym = symFromPoint(pe, idToGene, idToGroup, groupToSyms)
        if sym is None:
            logging.debug("Could not resolve %s to symbol" % pe.attrib)
            continue

        # anything that is not a typed group counts as a gene
        nodeType = groupToType.get(pe.attrib["GraphRef"], "gene")

        if "ArrowHead" in pe.attrib:
            head = pe.attrib["ArrowHead"]
            # known arrow heads get a relation name, others are passed through
            relType = intTypeNames.get(head, head)
            themes.append((sym, nodeType))
        else:
            causes.append((sym, nodeType))

    if len(themes) == 1 and len(causes) == 1:
        cause, causeNodeType = causes[0]
        theme, themeNodeType = themes[0]
    elif len(themes) == 0 and len(causes) == 2:
        # two partners joined by a line without arrow head: undirected binding.
        # The second partner becomes the theme and keeps its own type.
        relType = "binding"
        cause, causeNodeType = causes[0]
        theme, themeNodeType = causes[1]
    else:
        logging.debug("Wrong number of interaction partners %s: %s, %s" %
                      (inter.attrib, [s for s, _ in themes], [s for s, _ in causes]))
        return None

    # a "/" means the symbol came from a group with several members; only then
    # is the group type reported, single symbols are always "gene"
    causeType = causeNodeType if "/" in cause else "gene"
    causeGenes = cause.replace("/", "|")
    themeType = themeNodeType if "/" in theme else "gene"
    themeGenes = theme.replace("/", "|")

    if themeGenes == "" or causeGenes == "":
        return None
    # self-interactions are not reported
    if themeGenes == causeGenes:
        return None

    return causeType, causeGenes, themeType, themeGenes, relType

def pipeSplitAddAll(string, dict, key):
    """ Split string on "|" and set dict[value] = key for every non-empty value.

    Empty values (an empty string, or empty parts as in "a||b") are skipped:
    otherwise an "" entry would end up in dict and any later lookup of an
    empty accession would resolve to an unrelated key.
    """
    for val in string.split("|"):
        if val == "":
            continue
        dict[val] = key

def parseUniprot(uniprotTable):
    """ Parse the tab-separated UniProt table and return a dict accession -> symbol.

    uniprotTable: path of a file with a header line that includes the columns
        hgncSym, entrezGene, mainIsoAcc, accList, acc and ensemblGene;
        values in these columns may be "|"-separated lists.

    The symbol of a row is the first entry of hgncSym. The returned dict maps
    the symbol itself and every value of the other five columns to it. Rows
    without a symbol are skipped, so that their accessions are not mapped to
    an empty string.
    """
    logging.info("Parsing %s" % uniprotTable)
    accToSym = {}
    # the with-block makes sure the table is closed again after parsing
    with open(uniprotTable) as fh:
        for row in lineFileNext(fh):
            sym = row.hgncSym.split("|")[0]
            if sym == "":
                continue
            accToSym[sym] = sym # some entries have the HGNC symbol in the xref
            for accField in (row.entrezGene, row.mainIsoAcc, row.accList, row.acc, row.ensemblGene):
                pipeSplitAddAll(accField, accToSym, sym)
    return accToSym
        
# ----------- MAIN --------------
if not args:
    parser.print_help()
    # sys.exit rather than exit(): the latter is a helper added by the site
    # module for interactive use and is not guaranteed to exist
    sys.exit(1)

filenames = args

uniprotTable = "/hive/data/inside/pubs/parsedDbs/uniprot.9606.tab"
accToSym = parseUniprot(uniprotTable)

# header line, then one tab-separated line per interaction of every input file
sys.stdout.write("#" + "\t".join(headers) + "\n")
for filename in filenames:
    logging.debug(filename)
    for row in parseGpml(filename, accToSym):
        line = u"\t".join(row)
        if not isinstance(line, str):
            # Python 2: the joined line is a unicode object and has to be
            # encoded before it is written to stdout
            line = line.encode("utf8")
        sys.stdout.write(line + "\n")
