#!/usr/bin/env python2.7

import logging, sys, optparse
from collections import defaultdict
from os.path import join, basename, dirname, isfile
import xml.etree.ElementTree as et

# readable names for the relation type codes used in KGML files
# see http://www.kegg.jp/kegg/xml/docs/ for this
intTypeNames = dict(
    PPrel = "protein-protein",
    ECrel = "enzyme-enzyme",
    GErel = "gene expression",
    PCrel = "protein-compound",
)

# global var to identify reactions: running counter, incremented by parseKegg
# and used to build the "kegg<number>" event IDs
keggId = 0

# output file headers, one name per column of the rows built by parseKegg
headers = [
    "eventId",
    "causeType", "causeName", "causeGenes",
    "themeType", "themeName", "themeGenes",
    "relType", "relSubtype",
    "sourceDb", "sourceId", "sourceDesc",
    "pmids",
]

# === COMMAND LINE INTERFACE, OPTIONS AND HELP ===
# Positional arguments are the KGML files to parse. The options only set the
# location of the KEGG mirror, the location of the HGNC table and the logging
# verbosity.
# NOTE(review): the usage text mentions a 06062011 mirror directory while the
# --keggDir default below points to 06032011 -- confirm which one is current.
parser = optparse.OptionParser("""usage: %prog [options] filename - parse KEGG kgml files to tab-sep format of gene/family/complex + gene/family/complex + interaction type + subtype 

A copy of these XML files is in /hive/data/outside/kegg/06062011/ftp.genome.jp/pub/kegg/release/current/kgml/non-metabolic/organisms/hsa
""")

parser.add_option("-d", "--debug", dest="debug", action="store_true", help="show debug messages") 
parser.add_option("-k", "--keggDir", dest="keggDir", action="store", help="the KEGG ftp mirror directory on disk, default %default", default="/hive/data/outside/kegg/06032011/ftp.genome.jp/pub/kegg") 
parser.add_option("-s", "--hgncFname", dest="hgncFname", action="store", help="the HGNC tab file on disk, default %default", default="/hive/data/outside/hgnc/111413/hgnc_complete_set.txt") 
#parser.add_option("-f", "--file", dest="file", action="store", help="run on file") 
#parser.add_option("", "--test", dest="test", action="store_true", help="do something") 
# the command line is parsed at module level, so "options" and "args" are
# module globals that the main section at the end of the file reads
(options, args) = parser.parse_args()

# -d switches the root logger from INFO to DEBUG
if options.debug:
    logging.basicConfig(level=logging.DEBUG)
else:
    logging.basicConfig(level=logging.INFO)
# ==== FUNCTIONs =====
def slurpdict(fname, comments=False, valField=1, doNotCheckLen=False, asFloat=False, otherFields=False, asInt=False, headers=False, keyAsInt=False, encoding=None, keyField=0):
    """ parse a tab-separated file with one key -> value pair per line and
    return it as a dictionary key -> value (1:1 relationship)

    fname:         file to read; None or "" returns an empty dictionary
    comments:      if True, lines that start with "#" are skipped
    valField:      index of the field used as the value, -1 for the last field
    doNotCheckLen: if True, lines with a single field are kept as key -> None,
                   otherwise they are reported on stderr and ignored
    asFloat:       convert the value with float()
    otherFields:   if True, the value is the list of all fields after the first
    asInt:         convert the value with int() (only if asFloat is not set)
    headers:       if True, the first line of the file is skipped
    keyAsInt:      convert the key with int(); this is only applied to lines
                   that have more than one field
    encoding:      accepted for compatibility with existing callers, not used
    keyField:      index of the field used as the key

    If a key occurs several times, the first value is kept and every later
    occurrence is reported on stderr.
    """
    if fname is None or fname=="":
        return {}

    data = {}
    # the with-block closes the file even when a conversion below raises
    with open(fname) as f:
        if headers:
            # discard the header line
            f.readline()

        for l in f:
            if comments and l.startswith("#"):
                continue

            fs = l.strip().split("\t")
            if len(fs)>1:
                key = fs[keyField]
                if keyAsInt:
                    key = int(key)

                if otherFields:
                    val = fs[1:]
                else:
                    val = fs[valField]

                if asFloat:
                    val = float(val)
                elif asInt:
                    val = int(val)
            elif doNotCheckLen:
                # a key without a value
                key = fs[keyField]
                val = None
            else:
                sys.stderr.write("info: not enough fields, ignoring line %s\n" % l)
                continue

            if key not in data:
                data[key] = val
            else:
                sys.stderr.write("info: file %s, hit key %s two times: %s -> %s\n" %(fname, key, key, val))
    return data

def readKeggToSym(keggDir, hgncFname):
    """ return a dictionary KEGG gene ID -> gene symbol

    keggDir:   KEGG mirror directory, the file genes/organisms/hsa/hsa_hgnc.list
               below it is read as KEGG ID -> HGNC ID
    hgncFname: tab-separated HGNC file, read as first field -> second field
               (presumably HGNC ID -> symbol, verify against the file)

    The HGNC IDs from the KEGG file are upper-cased before the lookup. KEGG
    IDs whose HGNC ID is not in the HGNC file are skipped with a warning.
    """
    keggFname = join(keggDir, "genes/organisms/hsa/hsa_hgnc.list")
    logging.info("Parsing %s and %s" % (keggFname, hgncFname))
    keggToHgnc = slurpdict(keggFname)
    hgncToSym = slurpdict(hgncFname)

    keggToSym = {}
    # .items() exists in Python 2 and 3, .iteritems() only in Python 2
    for kegg, hgncId in keggToHgnc.items():
        hgncId = hgncId.upper()
        sym = hgncToSym.get(hgncId)
        if sym is None:
            # logging.warn is a deprecated alias of logging.warning
            logging.warning("No symbol for hgnc %s referenced by KEGG" % hgncId)
            continue
        keggToSym[kegg] = sym
    return keggToSym
    
#def resolveId(entryId, geneNames, complexNames, familyNames):
    #""" given an entryId and three dicts that map it to genes, complex or family names, return a tuple
    #entryType, entryName, entryGenes
    #"""
    #if entryId in geneNames:
        #return "gene", geneNames[entryId], geneNames[entryId]
    #elif entryId in complexNames:
        #genes = complexNames[entryId]
        #return "complex", "/".join(genes), "|".join(genes)
    #elif entryId in familyNames:
        #genes = familyNames[entryId]
        #return "family", "/".join(genes), "|".join(genes)
    #else:
        #assert(False)
    
def _keggEntryInfo(entryEl, keggToSym, entryNames):
    """ convert one KGML entry element to a tuple (type, name, genes)

    entryEl:    the entry element
    keggToSym:  dictionary KEGG accession -> gene symbol
    entryNames: entry id -> tuple for the entries converted so far, needed to
                resolve the components of group entries

    Returns None for entries of type "map", these are ignored. The tuples are:
      gene entries      ("gene" or "family", "", symbols joined with "|")
      ortholog/compound (entry type, accessions joined with "/", "-")
      group entries     ("complex", "", symbols of all components joined with "|")
    Raises AssertionError for unknown entry types, for groups without
    components and for components that have not been converted yet.
    """
    # most entries are genes
    # <entry id="18" name="hsa:4091 hsa:4092" type="gene" link="http://www.kegg.jp/dbget-bin/www_bget?hsa:4091+hsa:4092">
    entryType = entryEl.attrib["type"]
    keggAccIds = entryEl.attrib["name"].split()

    if entryType=="map":
        # ignore links to other maps
        return None

    if entryType=="gene":
        symbols = []
        for acc in keggAccIds:
            sym = keggToSym.get(acc)
            if sym is None:
                logging.warning("No symbol for %s" % acc)
            else:
                symbols.append(sym)
        symbols.sort()

        # exactly one resolved symbol is a gene, anything else is a family
        if len(symbols)==1:
            eType = "gene"
        else:
            eType = "family"
        return (eType, "", "|".join(symbols))

    if entryType in ("ortholog", "compound"):
        return (entryType, "/".join(keggAccIds), "-")

    if entryType=="group":
        compIds = [ce.attrib["id"] for ce in entryEl.findall("component")]
        # explicit raise instead of assert, so the check survives python -O
        if len(compIds)==0:
            raise AssertionError("group entry %s has no components" % entryEl.attrib["id"])

        compGenes = set()
        for compId in compIds:
            # convert id to list of genes
            if compId not in entryNames:
                raise AssertionError("component %s of group entry %s is not a known entry" % (compId, entryEl.attrib["id"]))
            compGenes.update(entryNames[compId][-1].split("|"))
        return ("complex", "", "|".join(sorted(compGenes)))

    logging.error("Unkown entry type %s" % entryType)
    raise AssertionError("unknown entry type %s" % entryType)

def parseKegg(filename, keggToSym):
    """ parse one KGML file and return its relations as a list of rows

    filename:  the KGML file; its basename up to the first "." is used as the
               pathway ID
    keggToSym: dictionary KEGG accession -> gene symbol

    Every row is a list of strings in the order of the output file headers:
    [eventId, causeType, causeName, causeGenes, themeType, themeName,
     themeGenes, relType, relSubtype, "kegg", pathway ID, pathway title, ""]
    The eventId is "kegg" plus the module-level counter keggId, which keeps
    counting across files.

    Files without relations return an empty list. Relations are skipped with
    a warning if their type has no name in intTypeNames (the KGML format also
    defines "maplink") or if one of their entries was not recorded, e.g.
    because it is a link to another map.
    """
    global keggId

    root = et.parse(filename).getroot()

    # a little check first
    if len(root.findall("relation"))==0:
        logging.warning("no relations found, skipping file %s" % filename)
        return []

    pathway = basename(filename).split(".")[0]
    title = root.attrib["title"]

    rows = []
    entryNames = {} # id -> tuple of (type, name, geneList)

    for c in root:
        if c.tag=="entry":
            entryId = c.attrib["id"]
            info = _keggEntryInfo(c, keggToSym, entryNames)
            if info is not None:
                entryNames[entryId] = info

        elif c.tag=="relation":
            #     <relation entry1="48" entry2="35" type="PPrel">
            # <subtype name="activation" value="--&gt;"/>
            # </relation>
            id1 = c.attrib["entry1"]
            id2 = c.attrib["entry2"]
            kgmlType = c.attrib["type"]

            relType = intTypeNames.get(kgmlType)
            if relType is None:
                logging.warning("Skipping relation %s -> %s of unsupported type %s in %s" % (id1, id2, kgmlType, filename))
                continue

            if id1 not in entryNames or id2 not in entryNames:
                logging.warning("Skipping relation %s -> %s in %s, it references an ignored or unknown entry" % (id1, id2, filename))
                continue

            # only the first subtype element of a relation is used
            se = c.find("subtype")
            if se is not None:
                relSubtype = se.attrib["name"]
            else:
                relSubtype = ""

            id1Type, id1Name, id1Genes = entryNames[id1]
            id2Type, id2Name, id2Genes = entryNames[id2]

            keggId += 1
            keggIdStr = "kegg%d" % keggId

            # NOTE(review): "-" is what ortholog/compound entries get in the
            # genes field, not in the name field, so these checks may have
            # been meant for id1Genes/id2Genes -- confirm the intent
            if id1Name=="-" or id2Name=="-":
                continue

            rows.append([keggIdStr, id1Type, id1Name, id1Genes, id2Type, id2Name, id2Genes, relType, relSubtype, "kegg", pathway, title, ""])

    return rows

# ----------- MAIN --------------
# Convert every KGML file given on the command line and print one
# tab-separated row per relation to stdout, after a "#"-prefixed header line.
if not args:
    parser.print_help()
    # sys.exit instead of exit(): exit() is a helper that the site module adds
    # for interactive use and is not available in every interpreter setup
    sys.exit(1)

filenames = args

keggToSym = readKeggToSym(options.keggDir, options.hgncFname)

# print(...) with a single argument gives the same output with the Python 2
# print statement and with the Python 3 print function
print("#"+"\t".join(headers))
for filename in filenames:
    logging.debug(filename)
    rows = parseKegg(filename, keggToSym)
    for row in rows:
        print("\t".join(row))
