#!/usr/bin/env python2.7

import logging, sys, optparse, string, os
from collections import defaultdict
from os.path import join, basename, dirname, isfile

# links are generated like this:
# http://biocyc.org/HUMAN/NEW-IMAGE?type=PATHWAY&object=BETA-ALA-DEGRADATION-I-PWY
# http://www.reactome.org/cgi-bin/link?SOURCE=Reactome&ID=REACT_1
# http://pantherdb.org/pathway/pathwayDiagram.jsp?catAccession=P04376

# === COMMAND LINE INTERFACE, OPTIONS AND HELP ===
# Build the option parser; the single positional argument is the superpathway
# directory (or a single .spf file, see the main section).
parser = optparse.OptionParser(
    "usage: %prog [options] superpathwayDirectory - parse Kyrle's superpathway input files into tab format excluding PID"
)
parser.add_option(
    "-d", "--debug",
    dest="debug",
    action="store_true",
    help="show debug messages",
)
parser.add_option(
    "-g", "--genes",
    dest="toGenes",
    action="store",
    help="resolve a name to genes, print result and stop",
)
options, args = parser.parse_args()

# Debug messages are opt-in via -d; otherwise log at INFO level.
logging.basicConfig(level=logging.DEBUG if options.debug else logging.INFO)
# ==== FUNCTIONs =====
    
def parseSpf(inFname, comp, fam, edges, ign):
    """ Parse one .spf file and add its content to the given containers.

    Each line is tab-separated and is classified by its last field:
      - "component>": field 1 is a member of the complex named in field 2
      - "member>":    field 1 is a member of the family named in field 2
      - "-a..." :     an interaction from field 1 (cause) to field 2 (theme);
                      relation type is "activate" if the last field ends
                      with ">", otherwise "inhibits"
    Lines whose first field is "rna" add their last field to ign.

    The source database and accession are taken from the name of the
    directory containing the file: the part before the first "." is split
    at the first "_" into (srcDb, acc). If there is no "_", acc is "".

    Arguments (all are modified in place and also returned):
      inFname -- path of the .spf file; if it is not an existing file,
                 nothing is parsed
      comp    -- mapping complex name -> set of member names (must create
                 missing sets on access, e.g. defaultdict(set))
      fam     -- mapping family name -> set of member names (same requirement)
      edges   -- list of (cause, theme, relType, srcDb, acc) tuples
      ign     -- set of names to ignore

    Returns the tuple (comp, fam, edges, ign).
    """
    if not isfile(inFname):
        return comp, fam, edges, ign

    logging.info(inFname)

    # The pathway reference only depends on the file path, so derive it once
    # instead of once per line. partition() tolerates a directory name
    # without "_" (e.g. when a single file is passed on the command line).
    pwRef = basename(dirname(inFname)).split(".")[0]
    srcDb, _, acc = pwRef.partition("_")

    # 'with' makes sure the file handle is closed again
    with open(inFname) as inFh:
        for line in inFh:
            # TRAF2	TNF-alpha_TNF-R1_TRAPP_RIP1_TRAF2_Complex_(complex)	component>
            parts = line.rstrip("\n").split("\t")
            lastField = parts[-1]

            if parts[0] == "rna":
                ign.add(lastField)

            if lastField == "component>":
                member, complexName = parts[:2]
                comp[complexName].add(member)
            elif lastField == "member>":
                member, famName = parts[:2]
                fam[famName].add(member)
            elif lastField.startswith("-a"):
                cause, theme = parts[:2]
                if lastField[-1] == ">":
                    relType = "activate"
                else:
                    relType = "inhibits"
                edges.append((cause, theme, relType, srcDb, acc))
    return comp, fam, edges, ign

def nameToGenes(name, comp, fam, _active=None):
    """ recursively resolve a name to the list of all gene symbols
    associated to it through either complexes or families.

    name    -- the name to resolve
    comp    -- mapping complex name -> collection of member names
    fam     -- mapping family name -> collection of member names
    _active -- internal: names currently being expanded further up the
               recursion; callers should not pass it

    A name that is neither a complex nor a family is returned as [name].
    comp is checked before fam. A gene reachable through several members
    appears once per path, and the order of the result follows the iteration
    order of the member collections.

    A complex or family that contains itself, directly or through its
    members, is not expanded a second time, so cyclic input cannot lead to
    unbounded recursion.
    """
    if name in comp:
        members = comp[name]
    elif name in fam:
        members = fam[name]
    else:
        return [name]

    if _active is None:
        _active = set()
    if name in _active:
        # already being expanded by one of our callers: cyclic reference
        return []

    _active.add(name)
    genes = []
    for member in members:
        genes.extend(nameToGenes(member, comp, fam, _active))
    # only names on the current path count as cyclic, so release this one again
    _active.discard(name)
    return genes

def resolveRef(name, comp, fam):
    """ classify a name and resolve it to its genes.

    Returns a tuple (entityType, name, genes) where entityType is "complex"
    if name is a key of comp, else "family" if it is a key of fam, else
    "gene", and genes is the "|"-separated result of nameToGenes for name.
    """
    geneStr = "|".join(nameToGenes(name, comp, fam))

    if name in comp:
        return ("complex", name, geneStr)
    if name in fam:
        return ("family", name, geneStr)
    return ("gene", name, geneStr)
# ----------- MAIN --------------
# Without a positional argument there is nothing to do: show the help text.
if not args:
    parser.print_help()
    # sys.exit instead of the exit() helper, which is only injected by the
    # site module and is meant for interactive use
    sys.exit(1)

inDir = args[0]

# Containers filled by parseSpf across all input files
comp = defaultdict(set)   # complex name -> set of member names
fam = defaultdict(set)    # family name -> set of member names
ign = set()               # names whose edges are skipped on output
edges = []                # (cause, theme, relType, srcDb, acc) tuples

# The argument is either a single .spf file or a directory with one
# subdirectory per pathway, each holding a graph.spf file.
if isfile(inDir):
    inFnames = [inDir]
else:
    # os.listdir returns entries in arbitrary order; sort them so that the
    # output order is reproducible between runs and machines
    inDirs = sorted(os.listdir(inDir))
    inFnames = [join(inDir, subDir, "graph.spf") for subDir in inDirs]

# parseSpf skips paths that are not existing files
for inFname in inFnames:
    comp, fam, edges, ign = parseSpf(inFname, comp, fam, edges, ign)

# -g/--genes: only resolve the given name, print the gene list and stop.
# (a parenthesised single argument prints the same with the Python 2 print
# statement)
if options.toGenes:
    print(nameToGenes(options.toGenes, comp, fam))
    sys.exit(0)

# One tab-separated row per interaction:
# causeType, causeName, causeGenes, themeType, themeName, themeGenes,
# relType, <empty>, srcDb, acc, <empty>
for (cause, theme, relType, srcDb, acc) in edges:
    # PID pathways are excluded from the output (see usage message)
    if srcDb == "PID":
        continue
    if cause in ign or theme in ign:
        continue
    row = list(resolveRef(cause, comp, fam))
    row.extend(resolveRef(theme, comp, fam))
    row.extend([relType, "", srcDb, acc, ""])
    print("\t".join(row))
