#!/usr/bin/env python

# load default python packages
import logging, optparse, sys, glob, gzip, collections, copy, gzip, os, doctest, re
from os.path import *
from collections import defaultdict

try:
    from lxml import etree # if this fails, comment out this line and uncomment the next one. Or do the 'pip install' below.
    #import xml.etree.cElementTree as etree # if using this line, search for cElementTree in this file and comment out the other part
except:
    raise Exception("lxml library not found. install elementtree with 'sudo apt-get install libxml2-dev libxslt-dev python-dev; pip install lxml' or just 'apt-get install python-lxml'. Or read the source code to get rid of the dependency.")

# module-wide debug flag; setupLogging() declares it global and sets it to True
# when debug output is requested
debugMode=False

# --- FASTA FILES ---
class FastaReader:
    """ a class to parse a fasta file

    fname can be an open file-like object (anything with a 'read' attribute),
    the string "stdin", a path ending in ".gz" (opened with gzip.open) or
    any other path (opened with open).

    Example:
        fr = FastaReader(filename)
        for (id, seq) in fr.parse():
            print(id, seq) """

    def __init__(self, fname):
        if hasattr(fname, 'read'):
            self.f = fname
        elif fname=="stdin":
            self.f = sys.stdin
        elif fname.endswith(".gz"):
            self.f = gzip.open(fname)
        else:
            self.f = open(fname)
        # id of the record whose sequence lines are currently being collected
        self.lastId = None

    def parse(self):
        """ Generator: returns sequences as tuple (id, sequence)

        Lines starting with a newline or with '#' are skipped. Spaces are removed
        from sequence lines. A record without any sequence lines is not
        yielded, a warning with its id is written to stderr instead.
        If there are no pending sequence lines at the end of the input
        (e.g. an empty file), a final (None, None) tuple is yielded.
        """
        lines = []

        for line in self.f:
            if line.startswith("\n") or line.startswith("#"):
                continue

            if not line.startswith(">"):
                lines.append(line.replace(" ","").strip())
                continue

            # header line: flush the record collected so far, if any
            if len(lines)!=0:
                faseq = (self.lastId, "".join(lines))
                self.lastId = line.strip(">").strip()
                lines = []
                yield faseq
            else:
                # on the first >, there is no previous record at all
                if self.lastId is not None:
                    # report the record that had no sequence, not the header we just read
                    sys.stderr.write("warning: when reading fasta file: empty sequence, id: %s\n" % self.lastId)
                self.lastId = line.strip(">").strip()

        # if it's the last sequence in a file, loop will end on the last line
        if len(lines)!=0:
            yield (self.lastId, "".join(lines))
        else:
            yield (None, None)

def parseFastaAsDict(fname, inDict=None):
    """ parse a fasta file and return it as a dict id -> sequence

    fname is anything FastaReader accepts. If fname is a path ending in
    ".gz" and a file with the same name minus ".gz" exists, the unzipped
    file is read instead.
    If inDict is given, the sequences are added to it and it is returned.
    Raises an Exception if an id is already in the dict.
    """
    if inDict is None:
        inDict = {}

    # only paths that really end in .gz can have an unzipped sibling;
    # file-like objects are passed through to FastaReader untouched
    if not hasattr(fname, "read") and fname.endswith(".gz"):
        unzippedName = fname[:-len(".gz")]
        if isfile(unzippedName):
            logging.warning("Preferring unzipped file %s" % unzippedName)
            fname = unzippedName

    fr = FastaReader(fname)
    for (seqId, seq) in fr.parse():
        # FastaReader.parse() ends with (None, None) if nothing is pending, not a real record
        if seqId is None and seq is None:
            continue
        if seqId in inDict:
            raise Exception("%s already seen before" % seqId)
        inDict[seqId] = seq
    return inDict

class ProgressMeter:
    """ prints a message "x%" to stderr every taskCount/stepCount calls of taskCompleted()

    taskCount is the total number of tasks, stepCount the number of
    messages to print. If there are fewer tasks than steps, no percentages
    are printed at all.
    """
    def __init__(self, taskCount, stepCount=20, quiet=False):
        self.taskCount = taskCount
        self.stepCount = stepCount
        # floor division, so the modulo test in taskCompleted() works on whole numbers
        self.tasksPerMsg = taskCount // stepCount
        self.i = 0
        self.quiet = quiet

    def taskCompleted(self, count=1):
        """ register 'count' finished tasks and print the percentage if a step was reached """
        # NOTE(review): quiet only takes effect for meters with at most 5 tasks - confirm this is intended
        if self.quiet and self.taskCount<=5:
            return
        if self.tasksPerMsg!=0 and self.i % self.tasksPerMsg == 0:
            donePercent = (self.i*100) // self.taskCount
            sys.stderr.write("%.2d%% " % donePercent)
            sys.stderr.flush()
        self.i += count
        if self.i==self.taskCount:
            # finish the progress line on the same stream the percentages went to
            sys.stderr.write("\n")
            sys.stderr.flush()

def setupLogging(progName, options, parser=None, logFileName=None, \
        debug=False, fileLevel=logging.DEBUG, minimumLog=False, fileMode="w"):
    """ direct logging to a file and also to stderr, depending on options

    progName    : must not be None, otherwise unused here
    options     : object with an optional 'debug' attribute (e.g. optparse options).
                  If None, the console shows messages at DEBUG level.
    parser      : unused
    logFileName : if not None and not empty, also log to this file, at DEBUG level
    debug       : if options is not None, show DEBUG messages on the console
    fileLevel   : the root logger level is set to the more verbose of this level and the console level
    minimumLog  : unused
    fileMode    : mode used to open logFileName, "w" overwrites, "a" appends

    All handlers that were on the root logger are dropped. Sets the global
    debugMode to True if debugging was requested via options.debug or debug.
    """
    assert(progName!=None)
    global debugMode

    stdoutLevel = logging.INFO
    if options is None:
        stdoutLevel = logging.DEBUG
    elif getattr(options, "debug", False) or debug:
        stdoutLevel = logging.DEBUG
        debugMode = True

    rootLog = logging.getLogger('')
    # remove handlers installed by earlier calls, so messages are not duplicated
    logging.root.handlers = []

    # setup file logger
    if logFileName is not None and logFileName!="":
        # honor fileMode: without it, FileHandler always appends
        fh = logging.FileHandler(logFileName, mode=fileMode)
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
        rootLog.addHandler(fh)

    # define a handler which writes messages to sys.stderr,
    # with a format which is simpler for console use
    console = logging.StreamHandler()
    console.setFormatter(logging.Formatter('%(levelname)-8s-%(message)s'))
    console.setLevel(stdoutLevel)

    # make sure that the root logger gets verbose messages
    rootLog.setLevel(min(stdoutLevel, fileLevel))
    rootLog.addHandler(console)


# UNIPROT PARSING 

# only parse these feature types
# anything else triggers a warning at the end of the parse
# As of 2017, these are all types of annotations
featTypes = {
    "splice variant" : "splicing",
    "sequence variant": "variant",
    "sequence conflict": "conflict",
    "mutagenesis site": "mutagen",
    "modified residue": "modif",
    "cross-link": "cross-link",
    "region of interest": "interest",
    "short sequence motif": "motif",
    "metal ion-binding site": "ion-binding",
    "site": "site",
    "topological domain" : "topo",
    "transmembrane region" : "transmem",
    "disulfide bond" : "disulf bond",
    "glycosylation site" : "glyco",
    "binding site" : "bind",
    "active site" : "enzyme act site",
    "signal peptide" : "signal pep",
    "transit peptide" : "trans pep",
    "calcium-binding region" : "calcium bind",
    "lipid moiety-binding region" : "lipid",
    "propeptide" : "propep",
    "intramembrane region" : "intramem",
    "peptide" : "peptide",
    "nucleotide phosphate-binding region" : "nucl phos bind",
    "helix" : "helix",
    "chain" : "chain",
    "coiled-coil region" : "coiled-coil",
    "turn" : "turn",
    "strand" : "beta",
    "domain" : "domain",
    "zinc finger region" : "zinc finger",
    "repeat" : "repeat",
    "compositionally biased region" : "biased",
    "initiator methionine" : "init Met",
    "non-standard amino acid" : "non-std",
    "non-consecutive residues" : "non-consec",
    "unsure residue" : "unsure",
    "DNA-binding region" : "DNA-binding",
    "non-terminal residue" : "nonTerm"
}

# fields of the main record table, one EntryRec per UniProt entry
entryHeaders = [
    "dataset", "acc", "mainIsoAcc", "orgName", "orgCommon", "taxonId", "name", "accList",
    "protFullNames", "protShortNames", "protAltFullNames", "protAltShortNames",
    "geneName", "geneSynonyms", "isoNames",
    "geneOrdLocus", "geneOrf",
    "hgncSym", "hgncId", "refSeq", "refSeqProt", "entrezGene", "ensemblGene", "ensemblProt", "ensemblTrans",
    "kegg", "emblMrna", "emblMrnaProt", "emblDna", "emblDnaProt",
    "pdb", "ec",
    "uniGene", "omimGene", "omimPhenotype", "subCellLoc", "functionText", "isoIds",
]
EntryRec = collections.namedtuple("uprec", entryHeaders)

# fields of the annotation table, every parsed feature becomes one AnnotRec
annotHeaders = [
    "acc", "mainIsoAcc", "varId", "featType", "shortFeatType",
    "begin", "end", "origAa", "mutAa", "dbSnpId",
    "disRelated", "disease", "disCode", "pmid", "comment",
]
AnnotRec = collections.namedtuple("mutrec", annotHeaders)

# fields of the reference table, one RefRec per literature reference of a record
refHeaders = [
    "name", "citType", "year", "journal", "vol", "page",
    "title", "authors", "doi", "pmid", "scopeList",
]
RefRec = collections.namedtuple("refRec", refHeaders)
# template for a reference with all fields set to the empty string
emptyRef = dict.fromkeys(refHeaders, "")

def strip_namespace_inplace(etree, namespace=None,remove_from_attr=True):
    """ Takes a parsed ET structure and does an in-place removal of all namespaces,
        or removes a specific namespace (by its URL).

        Can make node searches simpler in structures with unpredictable namespaces
        and in content given to be non-mixed.

        By default does so for node names as well as attribute names.
        (doesn't remove the namespace definitions, but apparently
         ElementTree serialization omits any that are unused)

        Nodes whose tag is not a string (e.g. comments) are left untouched.

        Note that for attributes that are unique only because of namespace,
        this may cause attributes to be overwritten.
        For example: <e p:at="bar" at="quu">   would become: <e at="bar">

        I don't think I've seen any XML where this matters, though.

        etree            : element or tree, anything with an iter() method
        namespace        : namespace URL to remove, None removes all namespaces
        remove_from_attr : if False, only tags are changed, not attribute names
    """
    try:
        stringTypes = basestring # python2
    except NameError:
        stringTypes = str

    if namespace is None:
        nsPrefix = None # all namespaces
    else:
        nsPrefix = '{%s}' % namespace

    def stripNs(name):
        " return name without its {namespace} prefix or None if there is nothing to strip "
        if nsPrefix is None:
            if name.startswith('{'):
                return name[name.index('}',1)+1:]
        elif name.startswith(nsPrefix):
            return name[len(nsPrefix):]
        return None

    # iter() replaces the deprecated getiterator()
    for elem in etree.iter():
        # comments and processing instructions do not have a string as their tag
        if not isinstance(elem.tag, stringTypes):
            continue

        newTag = stripNs(elem.tag)
        if newTag is not None:
            elem.tag = newTag

        if remove_from_attr:
            # collect the renamed attributes first, do not modify attrib while iterating over it
            renamed = {}
            for attrName in list(elem.attrib):
                newName = stripNs(attrName)
                if newName is not None:
                    renamed[newName] = elem.attrib.pop(attrName)
            elem.attrib.update(renamed)


def parseDiseases(fname):
    """ parse the file humanDiseases.txt from uniprot to resolve disease IDs to disease names

    Returns a dict disease code -> disease name. The name is taken from the
    most recent line starting with "ID", the code from lines starting with "AR".
    For both, everything after the first five characters is used, with
    leading and trailing dots removed.
    """
    logging.info("Parsing %s" % fname)
    dis = {}
    name = None
    with open(fname) as ifh:
        for line in ifh.read().splitlines():
            if line.startswith("ID"):
                name = line[5:].strip(".")
            elif line.startswith("AR"):
                code = line[5:].strip(".")
                # a code without a preceding ID line cannot be resolved to a name
                if name is None:
                    logging.warning("Disease code %s has no preceding ID line, skipping it" % code)
                    continue
                dis[code] = name
    logging.info("read %d disease code -> disease name mappings" % len(dis))
    return dis

def findSaveList(el, path, dataDict, key, attribKey=None, attribVal=None, useAttrib=None, subSubEl=None):
    """ find all text of subelements matching path and save it, joined with "|", into dataDict[key]

    attribKey, attribVal : if attribKey is given, only use subelements where this attribute has the value attribVal
    useAttrib            : take the value of this attribute instead of the text
    subSubEl             : take the text of this child of the subelement instead of the subelement's own text

    Elements without text (or without the subSubEl child) contribute an
    empty string, so the number of "|"-separated fields always equals the
    number of matching subelements.
    """
    vals = []
    for se in el.findall(path):
        if attribKey is not None and se.attrib.get(attribKey, None)!=attribVal:
            continue
        if useAttrib:
            val = se.attrib[useAttrib]
        else:
            textEl = se.find(subSubEl) if subSubEl else se
            val = textEl.text if textEl is not None else None
        # .text is None for empty elements, join() cannot handle that
        if val is None:
            val = ""
        vals.append(val)
    dataDict[key] = "|".join(vals)

def openOutTabFile(subDir, outName, headers):
    """ open the file outName in subDir for writing and return the file handle

    subDir is created if it does not exist. The headers are written as the
    first line, separated by tabs.
    """
    dirMissing = not isdir(subDir)
    if dirMissing:
        logging.info("Creating dir %s" % subDir)
        os.makedirs(subDir)

    outPath = join(subDir, outName)
    logging.debug("Writing output to %s" % outPath)

    ofh = open(outPath, "w")
    headerLine = "\t".join(headers)
    ofh.write(headerLine+"\n")
    return ofh

def findDisCodes(text, disToName):
    """ return the set of all words in parentheses in text that are keys of disToName
    >>> sorted(findDisCodes("Defects in HAL are the cause of histidinemia (HISTID) ", {"HISTID" : "Histidinemia"}))
    ['HISTID']
    """
    inParens = (m.group(1) for m in re.finditer("[(]([a-zA-Z0-9- ]+)[)]", text))
    return set(word for word in inParens if word in disToName)

# The original code below tried to guess the disease acronyms from the comment text.
# These days, UniProt provides a file with most of the acronyms (see parseDiseases).
# The old code is kept here, commented out, in case UniProt ever stops updating that file.

#def findDiseases(text):
#    """ find disease codes and their names in text, return as dict code -> name 
#    >>> findDiseases("Defects in CEACAM16 are the cause of deafness autosomal dominant type 4B (DFNA4B) [MIM:614614].")
#    {'DFNA4B': 'deafness autosomal dominant type 4B'}
#    >>> findDiseases("Defects in ALX4 are the cause of parietal foramina 2 (PFM2) [MIM:609597]; also known as foramina parietalia permagna (FPP). PFM2 is an autosomal dominant disease characterized by oval defects of the parietal bones caused by deficient ossification around the parietal notch, which is normally obliterated during the fifth fetal month. PFM2 is also a clinical feature of Potocki-Shaffer syndrome.")
#    {'PFM2': 'parietal foramina 2', 'FPP': 'foramina parietalia permagna'}
#
#    # disease is only one word, but long enough
#    >>> findDiseases("Defects in HAL are the cause of histidinemia (HISTID) ")
#    {'HISTID': 'histidinemia'}
#    """
#    result = {}
#    phrases = re.split("[;.] ", text)
#    notDisease = set(["of", "with", "to", "as", "or", "also", "in"])
#
#    for phrase in phrases:
#        words = phrase.split()
#        revWords = reversed(words)
#
#        grabWords = False
#        disWords = []
#        disCode = None
#        # go backwords over words and look for acronym, then grab all words before that
#        # until we find a common English word
#        for word in revWords:
#            m = re.match("[(]([A-Z0-9-]+)[)]", word)
#            if m!=None:
#                disCode = m.group(1)
#                grabWords = True
#                continue
#
#            if word in notDisease and (len(disWords)>1 or len("".join(disWords))>=9):
#                disName = " ".join(list(reversed(disWords)))
#                if disCode==None:
#                    logging.debug("Found disease %s, but no code for it" % disName)
#                    continue
#                result[disCode] = disName
#                disCode = None
#                disWords = []
#                grabWords = False
#
#            if grabWords:
#                disWords.append(word)
#
#    return result

#def parseDiseaseComment(entryEl):
#    """ return two dicts 
#    one with evidence code -> disease code
#    one with disease code -> disease name 
#    """
#    disRefs = {}
#    disCodes = {}
#    for commentEl in entryEl.findall("comment"):
#        textEl = commentEl.find("text")
#        if commentEl.attrib["type"]=="disease":
#            refStr = commentEl.attrib.get("evidence", None)
#            # website xml is different, has evidence attribute on text element
#            if refStr==None:
#                refStr = textEl.attrib.get("evidence", None)
#                if refStr==None:
#                    continue
#
#            refs = refStr.split(" ")
#
#            text = textEl.text
#            logging.debug("Disease comment: %s, evidence %s" % (text, refStr))
#            disCodes.update(findDiseases(text))
#
#            for refId in refs:
#                disRefs[refId] = disCodes
#
#    logging.debug("Found disease evidences: %s" % disRefs)
#    logging.debug("Found disease names: %s" % disCodes)
#    return disRefs, disCodes

def parseDiseaseComment(entryEl, disToName):
    """
    parse the general comments, disease section from up record

    Only comments with type="disease" that have an evidence attribute, on
    the comment itself or on its <text> child, are used.

    Returns a tuple (disRefs, disCodes):
    disRefs  : dict evidence id -> set of disease codes found in the comments with this evidence
    disCodes : set of all disease codes found in these comments
    Disease codes are words in parentheses that are keys of disToName.
    """
    disRefs = {}
    disCodes = set()
    for commentEl in entryEl.findall("comment"):
        if commentEl.attrib.get("type")!="disease":
            continue

        textEl = commentEl.find("text")
        refStr = commentEl.attrib.get("evidence", None)
        # website xml is different, has evidence attribute on text element
        if refStr is None and textEl is not None:
            refStr = textEl.attrib.get("evidence", None)
        if refStr is None:
            continue

        # a comment without text cannot contain disease codes, but its evidence still counts
        text = ""
        if textEl is not None and textEl.text is not None:
            text = textEl.text
        logging.debug("Disease comment: %s, evidence %s" % (text, refStr))

        commentCodes = findDisCodes(text, disToName)
        disCodes.update(commentCodes)

        # give every evidence its own set, so codes from other comments do not leak into it
        for refId in refStr.split(" "):
            disRefs.setdefault(refId, set()).update(commentCodes)

    logging.debug("Found disease evidences: %s" % disRefs)
    return disRefs, disCodes

def parseIsoforms(entryEl, mainId):
    """ parse the isoforms of a record, returns: isoIds, isoNames, mainIsoId

    isoIds    : list of the ids of all isoforms with sequence type "described"
    isoNames  : list of the names (aka transcript synonyms) of all isoforms. One
                isoform can have many names, so this list is not parallel to isoIds.
    mainIsoId : id of the isoform with sequence type "displayed", or mainId
                if the record has no such isoform
    """
    isoIds = []
    isoNames = []
    mainIsoId = mainId
    for isoEl in entryEl.findall("comment/isoform"):
        # get id
        isoId = isoEl.find("id").text

        # get names (aka transcript synonyms), currently a reassignment to transcripts
        # is impossible from my files, I just collect them for now.
        # findall: iterating over the result of find() would go over the children
        # of the first <name> element and never return any name
        for nameEl in isoEl.findall("name"):
            if nameEl.text is not None:
                isoNames.append(nameEl.text)

        # get sequences
        seqType = isoEl.find("sequence").attrib["type"]
        if seqType=="displayed":
            # the main sequence also is an "isoform", it's very strange
            # one of the isoform sequences is "displayed" = main isoform
            # we have already covered these
            mainIsoId = isoId
        elif seqType=="described":
            isoIds.append(isoId)
        else:
            # weird Uniprot: they refer to sequences that they do not have
            assert(seqType in ["external", "not described"])

    return isoIds, isoNames, mainIsoId

def parseDbRefs(entryEl):
    """ return dict with db -> ids, joined with "|" (various special cases)

    The keys are the type attributes of the <dbReference> elements, plus these special keys:
    refseqNucl   : RefSeq nucleotide sequence IDs
    hgncGene     : HGNC gene designations
    ensemblGene, ensemblProt : Ensembl gene and protein IDs
    omimGene, omimPhenotype  : MIM ids, split by their type property
    emblMrna, emblMrnaProt   : EMBL accessions with molecule type mRNA and their protein IDs
    emblDna, emblDnaProt     : all other EMBL accessions and their protein IDs
    The EMBL values keep the order of the file and are parallel, "na" stands for
    a missing protein ID. All other values are sorted lists of unique ids.
    """
    dbRefs = defaultdict(set)
    # lists, not sets: accession and protein accession have to stay parallel
    dbRefs["emblMrna"] = []
    dbRefs["emblMrnaProt"] = []
    dbRefs["emblDna"] = []
    dbRefs["emblDnaProt"] = []

    for dbRefEl in entryEl.findall("dbReference"):
        db = dbRefEl.attrib["type"]
        mainId = dbRefEl.attrib["id"]
        if db=="EMBL": # special case, don't add yet
            emblId = mainId
        else:
            dbRefs[db].add(mainId)

        emblProtId = "na"
        for propEl in dbRefEl.findall("property"):
            propType = propEl.attrib["type"]
            propDb = db
            # reset for every property, so no id from a previous property is added again
            refId = None
            if (db, propType)==("RefSeq", "nucleotide sequence ID"):
                refId = propEl.attrib["value"]
                propDb = "refseqNucl"
            elif db=="HGNC" and propType=="gene designation":
                refId = propEl.attrib["value"]
                propDb = "hgncGene"
            elif db=="Ensembl" and propType=="gene ID":
                refId = propEl.attrib["value"]
                propDb = "ensemblGene"
            elif db=="Ensembl" and propType=="protein sequence ID":
                refId = propEl.attrib["value"]
                propDb = "ensemblProt"
            elif db=="EMBL" and propType=="protein sequence ID":
                # only remember it, it is added when the molecule type property is reached
                emblProtId = propEl.attrib["value"]
                continue
            elif db=="MIM" and propType=="type":
                omimCat = propEl.attrib["value"]
                if omimCat=="phenotype":
                    dbRefs["omimPhenotype"].add(mainId)
                elif omimCat=="gene" or omimCat=="gene+phenotype":
                    dbRefs["omimGene"].add(mainId)
                else:
                    assert False, "unknown OMIM type %s" % omimCat
            elif db=="EMBL" and propType=="molecule type":
                if propEl.attrib["value"]=="mRNA":
                    dbRefs["emblMrna"].append(emblId)
                    dbRefs["emblMrnaProt"].append(emblProtId)
                else:
                    dbRefs["emblDna"].append(emblId)
                    dbRefs["emblDnaProt"].append(emblProtId)
                continue # don't add any id
            else:
                refId = mainId

            if refId is not None:
                dbRefs[propDb].add(refId)

    result = {}
    for db, vals in dbRefs.items():
        # sets have no defined order, sort them to get the same output on every run
        if isinstance(vals, set):
            vals = sorted(vals)
        result[db] = "|".join(vals)

    logging.debug("dbRefs: %s" % result)
    return result

def splitAndResolve(disName, disCodes, splitWord):
    """ split disName on splitWord, replace every part that is a key of disCodes
    with its value and return the parts joined with commas """
    parts = [part.strip() for part in disName.split(splitWord)]
    resolved = [disCodes[part] if part in disCodes else part for part in parts]
    return ",".join(resolved)

def parseFeatDesc(text, disToName):
    """ 
    parse the description of a feature to find code name of disease, snpId and comments 
    return tuple: (disease name, dbSnpId, otherComments)
    >>> parseFeatDesc("In sporadic cancers; somatic mutation; dbSNP:rs11540654.", {})
    ('sporadic cancers', 'rs11540654', 'somatic mutation')
    >>> parseFeatDesc("In RIEG1; pointless comment", {"RIEG1" : "Axel-Riegerfeldt syndrome"})
    ('Axel-Riegerfeldt syndrome', '', 'pointless comment')
    """
    # find disease name and try to resolve via disToNames
    logging.debug("Feature description: %s " % (text))
    text = text.strip(".").strip()
    parts = text.split("; ")
    disCode = ""
    comments = []
    for part in parts:
        part = part.replace("a patient with", "")
        part = part.replace("in a ", "in ")
        partLow = part.lower()
        if partLow.startswith("in ") and "dbSNP" not in part and "allele" not in part:
            disCode = " ".join(part.split()[1:])
            # some entries contain two disease names
        else:
            if "dbSNP" not in part:
                comments.append(part)
                    
    # we got a plain disease code
    if disCode in disToName:
        disLongName = disToName[disCode]
    # or two dis codes with and
    elif " and " in disCode:
        disLongName = splitAndResolve(disCode, disToName, " and ")
    else:
        # there are dis code somewhere inside the text
        intDisCodes = findDisCodes(disCode, disToName)
        if len(intDisCodes)!=0:
            disLongName = disCode
            disCode = ",".join(intDisCodes)
        # ok nothing worked, keep it as it is
        else:
            disLongName = disCode

    # find snpId
    snpId = ""
    for m in re.finditer("dbSNP:(rs[0-9]+)", text):
        if m!=None:
            #assert(snpId=="")
            snpId = m.group(1)

    logging.debug("Disease: %s, snpId: %s" % (disLongName, snpId))
    return disCode, disLongName, snpId, "; ".join(comments)


# counts, per feature type, the features that parseFeatures() skipped because
# their type is not a key of featTypes
ignoredTypes = collections.Counter()

def parseFeatures(entryEl, disRefs, defaultDisCodes, disToName, evidPmids, mainIsoAcc):
    """ go over features and yield annotation records

    entryEl         : the <entry> element of a record
    disRefs         : evidence ids that refer to diseases, only membership is tested here
    defaultDisCodes : disease codes of the record. If there is exactly one, it is used for
                      sequence variants that have no disease in their description.
    disToName       : dict disease code -> disease name
    evidPmids       : dict evidence id -> collection of PMIDs, at most one PMID per evidence (asserted)
    mainIsoAcc      : copied into the mainIsoAcc field of every record

    Yields one AnnotRec per <feature> element with a type in featTypes.
    Features with an unknown begin or end position are skipped.
    """

    acc = entryEl.find("accession").text

    mutations = [] # not used anywhere in this function
    for featEl in entryEl.findall("feature"):
        featType = featEl.attrib["type"]
        if featType not in featTypes:
            ignoredTypes[featType] += 1 
            continue

        # the description of sequence variants is parsed for disease and dbSNP info further down
        if featType in ["sequence variant"]:
            isVariant = True
        else:
            isVariant = False

        shortFeatType = featTypes[featType]
        logging.debug("type: %s" % featType)

        varId = featEl.attrib.get("id", "")
        logging.debug("Variant ID %s" % varId)

        # original and changed residues, empty strings if the feature has none
        origEl = featEl.find("original")
        if origEl==None:
            orig = ""
        else:
            orig = origEl.text

        varEl = featEl.find("variation")
        if varEl==None:
            variant = ""
        else:
            variant = varEl.text
            logging.debug("residue change: %s->%s" % (orig, variant))

        posEl = featEl.find("location/position")
        # if features have only a single position, they have only <position position="xxx">
        # we convert them to begin=xxx and end=xxx+1
        if posEl!=None:
            begin = posEl.attrib["position"]
            end = str(int(begin)+1)
        else:
            beginEl = featEl.find("location/begin")
            begin = beginEl.attrib.get("position", None)
            if begin==None:
                logging.debug("Unknown start, skipping a feature")
                continue
            endEl = featEl.find("location/end")
            end = endEl.attrib.get("position", None)
            if end==None:
                logging.debug("Unknown end, skipping a feature")
                continue
            end = str(int(end)+1) # UniProt is 1-based, open-end

        desc = featEl.attrib.get("description", None)
        if desc==None:
            desc = ""
        # the description can override the short feature type from featTypes
        if "sulfinic" in desc:
            shortFeatType = "sulfo"

        # only the first word of the description is checked for these modification types,
        # the first match wins and also replaces "sulfo" from above
        descWords = desc.split()
        if len(descWords)>0:
            desc1 = descWords[0].lower()
            if "phos" in desc1:
                shortFeatType = "phos"
            elif "acetyl" in desc1:
                shortFeatType = "acetyl"
            elif "methyl" in desc1:
                shortFeatType = "methyl"
            elif "lipo" in desc1:
                shortFeatType = "lipo"
            elif "hydroxy" in desc1:
                shortFeatType = "hydroxy"
            elif "nitro" in desc1:
                shortFeatType = "nitro"

        # the evidence attribute is a whitespace-separated list of evidence ids
        evidStr = featEl.attrib.get("evidence", "")
        logging.debug("annotation pos %s-%s, desc %s, evidence %s" % (begin, end, desc, evidStr))
        desc = desc.strip("() ")
        evidList = evidStr.split()

        if isVariant:
            # only do this for mutations
            disCode, disName, snpId, comments = parseFeatDesc(desc, disToName)
            # if no disease annotated to feature, use the one from the record
            if disCode=="" and len(defaultDisCodes)==1:
                disCode = list(defaultDisCodes)[0]
                disName = disToName.get(disCode, disCode)+" (not annotated on variant but on gene record)"
                disCode = disCode + "?" # the question mark marks the code as taken from the record
        else:
            # for all other features, the whole description goes into the comment field
            disCode, disName, snpId, comments = "", "", "", desc

        annotPmids = []
        # start value, used as it is only if the feature has no evidence ids
        if disCode!="":
            diseaseRelated = "disRelated"
        else:
            diseaseRelated = "noEvidence"

        for evidId in evidList:
            # NOTE(review): this overwrites the value on every iteration, so the last
            # evidence id decides and replaces the start value from above - confirm this is intended
            if evidId in disRefs:
                diseaseRelated="disRelated"
            else:
                diseaseRelated="notDisRelated"
                logging.debug("evidence is not a disease evidence or blacklisted, check description")

            pmids = evidPmids.get(evidId, [])
            assert(len(pmids)<=1) # stops the parse if an evidence has more than one PMID
            if len(pmids)>0:
                pmid = list(pmids)[0]
                annotPmids.append(pmid)

        annot = AnnotRec(acc, mainIsoAcc, varId, featType, shortFeatType, begin, end, orig, variant, snpId, diseaseRelated, disName, disCode, ",".join(annotPmids), comments)
        logging.debug("Accepted annotation: %s" % str(annot))

        yield annot

def parseEvidence(entryEl):
    """ return a dict with evidCode -> list of PMIDs
    Evidences without a PubMed reference are not part of the dict.
    """
    pmidsByEvid = {}
    for evidEl in entryEl.findall("evidence"):
        evidCode = evidEl.attrib["key"]
        pubmedIds = [refEl.attrib["id"] for refEl in evidEl.findall("source/dbReference") \
            if refEl.attrib["type"]=="PubMed"]
        if pubmedIds:
            pmidsByEvid.setdefault(evidCode, []).extend(pubmedIds)
    return pmidsByEvid
    
def parseAnnotations(entryEl, mainIsoAcc, disToName):
    """ return the list of all annotation records of an entry, as produced by parseFeatures() """
    # the general record comments about diseases and the evidence -> PMID
    # table are needed to annotate the features
    disRefs, allDiseaseCodes = parseDiseaseComment(entryEl, disToName)

    logging.debug("Diseases in %s" % entryEl.find("accession").text)

    evidPmids = parseEvidence(entryEl)
    featIter = parseFeatures(entryEl, disRefs, allDiseaseCodes, disToName, evidPmids, mainIsoAcc)
    return list(featIter)

def parseRecInfo(entryEl, entry, isoSeqs):
    """parse uniprot general record info into entry dict

    entryEl : the <entry> element of a record
    entry   : dict that is filled with one key per field of entryHeaders
    isoSeqs : dict isoform id -> sequence, used to look up the sequences of the isoforms

    Returns a tuple (entryRow, seqs): entryRow is an EntryRec built from
    entry, seqs is a list of (id, sequence). The first element of seqs is
    the main sequence of the record with the id "<mainIsoId> isRefOf <acc>",
    followed by all isoforms that have a sequence in isoSeqs.
    """
    entry["dataset"] = entryEl.attrib["dataset"]

    findSaveList(entryEl, "name", entry, "name")
    findSaveList(entryEl, "accession", entry, "accList")
    # the first accession is the primary one
    acc = entry["accList"].split("|")[0]
    entry["acc"] = acc

    logging.debug("Parsing rec info for acc %s" % acc)

    findSaveList(entryEl, "protein/recommendedName/fullName", entry, "protFullNames")
    findSaveList(entryEl, "protein/recommendedName/shortName", entry, "protShortNames")
    findSaveList(entryEl, "protein/alternativeName/fullName", entry, "protAltFullNames")
    findSaveList(entryEl, "protein/alternativeName/shortName", entry, "protAltShortNames")
    findSaveList(entryEl, "gene/name", entry, "geneName", attribKey="type", attribVal="primary")
    findSaveList(entryEl, "gene/name", entry, "geneSynonyms", attribKey="type", attribVal="synonym")
    findSaveList(entryEl, "gene/name", entry, "geneOrdLocus", attribKey="type", attribVal="ordered locus")
    findSaveList(entryEl, "gene/name", entry, "geneOrf", attribKey="type", attribVal="ORF")
    findSaveList(entryEl, "organism/name", entry, "orgName", attribKey="type", attribVal="scientific")
    findSaveList(entryEl, "organism/name", entry, "orgCommon", attribKey="type", attribVal="common")
    findSaveList(entryEl, "organism/dbReference", entry, "taxonId", useAttrib="id")
    # isoIds and isoNames are not collected here, they are set from parseIsoforms() below
    findSaveList(entryEl, "comment/subcellularLocation/location", entry, "subCellLoc")
    findSaveList(entryEl, "comment", entry, "functionText", attribKey="type", attribVal="function", subSubEl="text")

    mainSeq = entryEl.find("sequence").text

    dbRefs = parseDbRefs(entryEl)

    isoIds, isoNames, mainIsoId = parseIsoforms(entryEl, acc)
    entry["mainIsoAcc"] = mainIsoId

    seqs = [ (mainIsoId+" isRefOf "+acc, mainSeq) ]
    for isoId in isoIds:
        if isoId not in isoSeqs:
            logging.warning("No sequence for isoform %s" % isoId)
            continue
        seqs.append( (isoId, isoSeqs[isoId]) )

    # field of entry -> key in the dict returned by parseDbRefs()
    dbRefFields = [
        ("hgncSym", "hgncGene"),
        ("hgncId", "HGNC"),
        ("refSeq", "refseqNucl"),
        ("refSeqProt", "RefSeq"),
        ("ensemblProt", "ensemblProt"),
        ("ensemblGene", "ensemblGene"),
        ("ensemblTrans", "Ensembl"),
        ("entrezGene", "GeneID"),
        ("kegg", "KEGG"),
        ("uniGene", "UniGene"),
        ("omimGene", "omimGene"),
        ("omimPhenotype", "omimPhenotype"),
        ("emblMrna", "emblMrna"), # mrnas
        ("emblMrnaProt", "emblMrnaProt"), # the protein accessions for mrnas
        # parseDbRefs() uses the lowercase keys, "EmblDna" never matched anything
        ("emblDna", "emblDna"), # anything not an mrna
        ("emblDnaProt", "emblDnaProt"), # protein accessions for non-mrnas
        ("pdb", "PDB"),
        ("ec", "EC"),
    ]
    for entryKey, dbKey in dbRefFields:
        entry[entryKey] = dbRefs.get(dbKey, "")

    entry["isoIds"] = "|".join(isoIds)
    entry["isoNames"] = "|".join(isoNames)

    entryRow = EntryRec(**entry)
    return entryRow, seqs

def parseRefInfo(entryEl, recName):
    """ Generator: yield one RefRec for every <reference> child of entryEl.

    entryEl -- parsed XML element of a single entry
    recName -- value stored in the "name" field of every yielded record

    Each record starts as a copy of emptyRef and is filled from the attributes
    and children of the reference's <citation> element. The "scopeList" field
    is filled by findSaveList from the <scope> children.
    """
    # citation dbReference type -> field of the reference record
    xrefFields = {"DOI": "doi", "PubMed": "pmid"}

    for refEl in entryEl.findall("reference"):
        ref = copy.copy(emptyRef)
        ref["name"] = recName

        citEl = refEl.find("citation")
        citAttrs = citEl.attrib
        ref["citType"] = citAttrs["type"]
        # dates look like year-month-day, only the first part is kept
        ref["year"] = citAttrs.get("date", "").split("-")[0]

        journal = citAttrs.get("name", "")
        if journal=="":
            journal = citAttrs.get("db", "") # for submissions
        ref["journal"] = journal

        ref["vol"] = citAttrs.get("volume", "")
        ref["page"] = citAttrs.get("first", "")

        # if there are several titles, the last one wins
        titleEls = citEl.findall("title")
        if len(titleEls)!=0:
            ref["title"] = titleEls[-1].text

        # only the first space of an author name is replaced with a comma
        authors = [personEl.attrib["name"].replace(" ", ",", 1)
                   for personEl in citEl.findall("authorList/person")
                   if "name" in personEl.attrib]
        ref["authors"] = ";".join(authors)

        for dbRefEl in citEl.findall("dbReference"):
            field = xrefFields.get(dbRefEl.attrib.get("type"))
            if field is not None:
                ref[field] = dbRefEl.attrib["id"]

        findSaveList(refEl, "scope", ref, "scopeList")
        yield RefRec(**ref)

def readIsoforms(inDir, db):
    """ return all isoform sequences as dict isoName (eg. P48347-2) -> sequence

    inDir -- directory that contains the UniProt fasta files
    db    -- "swissprot" or "trembl", selects the fasta file to read

    The key of the result is the second "|"-separated field of the fasta ID.
    Raises ValueError if db is not one of the two supported names.
    """
    if db=="swissprot":
        isoFname = join(inDir, "uniprot_sprot_varsplic.fasta.gz")
    elif db=="trembl":
        isoFname = join(inDir, "uniprot_trembl.fasta.gz")
    else:
        # not an assert: asserts are removed when running with -O and isoFname
        # would then be undefined further down
        raise ValueError("unknown db: %s" % db)

    logging.info("reading isoform sequences from %s (or non-gz version)" % isoFname)
    isoSeqs = parseFastaAsDict(isoFname)
    result = {}
    for seqId, seq in isoSeqs.iteritems():
        idParts = seqId.split("|")
        isoName = idParts[1]
        result[isoName] = seq
    logging.info("Found %d isoform sequences" % len(result))
    return result

def writeFaSeqs(faFiles, taxonId, seqs):
    """ append all sequences to the fasta file for taxonId

    faFiles -- dict of open file handles. If it has a key "all", that handle
               is used for every taxon, otherwise the handle faFiles[taxonId].
    taxonId -- taxon of the sequences, only used to select the output file
    seqs    -- iterable of (seqId, sequence) tuples

    ID and sequence are both stripped of surrounding whitespace.
    """
    outKey = "all" if "all" in faFiles else taxonId
    faFh = faFiles[outKey]

    for faId, faSeq in seqs:
        record = ">%s\n%s\n" % (faId.strip(), faSeq.strip())
        faFh.write(record)

def openFaFiles(taxonIds, outDir, outPrefix):
    """ open one gzipped fasta output file per taxon, return dict taxonId -> file handle

    taxonIds  -- list of taxon IDs (integers or numeric strings) or ["all"].
                 None is treated like ["all"].
    outDir    -- directory for the output files
    outPrefix -- files are named <outPrefix>.<taxonId>.fa.gz

    Keys of the result are integers, except for the special taxon "all",
    which keeps the string key "all" (this is the key writeFaSeqs looks for).
    """
    faFiles = {}
    if taxonIds is None:
        taxonIds = ["all"]

    for taxonId in taxonIds:
        taxonId = str(taxonId)
        faFname = join(outDir, outPrefix+"."+taxonId+".fa.gz")
        # "all" cannot be converted to an integer, keep it as it is
        if taxonId=="all":
            faKey = taxonId
        else:
            faKey = int(taxonId)
        faFiles[faKey] = gzip.open(faFname, "w")
        logging.debug("Writing fasta seqs for taxon %s to %s" % (taxonId, faFname))
    return faFiles

def stupidXmlFilter(xmlFile, taxonIds):
    """ Generator: split the uniprot XML into entries, without parsing the XML.

    xmlFile  -- iterable of lines
    taxonIds -- list of integer taxon IDs or ['all']

    Yields one item per line starting with "</entry": the list of lines of
    the entry if its taxon is wanted, or None if it is not. The lines of
    unwanted entries are not collected after their taxon line has been seen.
    Lines outside of entries are dropped.
    """
    # taxon lines look like <dbReference id="654924" type="NCBI Taxonomy"/>
    taxPrefix = '    <dbReference id="'
    taxSuffix = 'type="NCBI Taxonomy"/>\n'
    keepAll = (taxonIds==['all'])

    entryLines = []
    keepEntry = False
    for line in xmlFile:
        if line.startswith("<entry"):
            # start of a new entry
            entryLines = [line]
            keepEntry = True # in case of doubt, it's True
            continue

        if line.startswith(taxPrefix) and line.endswith(taxSuffix):
            # NOTE(review): every line of this format can reject the entry,
            # not just the first one - confirm this is intended
            entryTaxon = int(line.split('"')[1])
            if not keepAll and entryTaxon not in taxonIds:
                keepEntry = False
            entryLines.append(line)
            continue

        if line.startswith("</entry"):
            # end of entry: hand over the lines or None and start afresh
            entryLines.append(line)
            yield entryLines if keepEntry else None
            entryLines = None
            continue

        if keepEntry and entryLines is not None:
            entryLines.append(line)

def parseUniprot(db, inDir, outDir, taxonIds):
    """ parse uniprot, write records and refs to outdir

    db       -- "swissprot" or "trembl", selects the XML file in inDir
    inDir    -- directory with the UniProt files (XML, isoform fasta, docs/humdisease.txt)
    outDir   -- output directory
    taxonIds -- list of integer taxon IDs or ['all']

    For every taxon ID (or "all"), three tab-sep files are written:
    <prefix>.<taxon>.tab, <prefix>.<taxon>.refs.tab and <prefix>.<taxon>.annots.tab,
    plus the fasta files opened by openFaFiles.

    If the module-level option --parse is set, only that single XML file is
    parsed, without isoform sequences or disease names, and the output goes
    to the current directory with the prefix "temp".

    Raises Exception if db is unknown.
    """

    if options.parse:
        fname = options.parse
        logging.info("Debug parse of %s" % fname)
        xmlFile = open(fname)
        isoSeqs, recCount = {}, 1
        outDir = "."
        outPrefix = "temp"
        disToName = {}
    else:
        isoSeqs = readIsoforms(inDir, db)
        # recCount is only passed to the progress meter
        if db=="swissprot":
            xmlBase = "uniprot_sprot.xml.gz"
            outPrefix = "swissprot"
            recCount = 600000
        elif db=="trembl":
            xmlBase = "uniprot_trembl.xml.gz"
            outPrefix = "trembl"
            recCount = 600000*100
        else:
            raise Exception("unknown db")

        # prefer an uncompressed copy of the XML file, if there is one in inDir.
        # The check has to use the same path as the open() call.
        xmlBase2 = xmlBase.replace(".gz", "")
        xmlFname2 = join(inDir, xmlBase2)
        if isfile(xmlFname2):
            logging.debug("Using non-gzipped file %s" % xmlBase2)
            xmlFile = open(xmlFname2)
        else:
            xmlFile = gzip.open(join(inDir, xmlBase))

        logging.info("Parsing main XML file %s" % xmlFile.name)
        disToName = parseDiseases(join(inDir, "docs", "humdisease.txt"))

    faFiles = openFaFiles(taxonIds, outDir, outPrefix)

    logging.debug("Only extracting taxon IDs %s" % str(taxonIds))

    # create a dict taxonId -> output file handles for record info, pmid reference info and annotation info
    outFhs = {}
    for taxId in taxonIds:
        entryOf = openOutTabFile(outDir, "%s.%s.tab" % (outPrefix, taxId), entryHeaders)
        refOf = openOutTabFile(outDir, "%s.%s.refs.tab" % (outPrefix, taxId), refHeaders)
        annotOf = openOutTabFile(outDir, "%s.%s.annots.tab" % (outPrefix, taxId), annotHeaders)
        outFhs[taxId] = (entryOf, refOf, annotOf)

    emptyEntry = dict(zip(entryHeaders, len(entryHeaders)*[""]))

    pm = ProgressMeter(recCount)
    # An earlier version used etree.iterparse on the whole file (a partial
    # parse), but that required 300GB of RAM. Now the file is split into
    # entries with string operations and every entry is parsed on its own.
    # https://stackoverflow.com/questions/12160418/why-is-lxml-etree-iterparse-eating-up-all-my-memory
    for recLines in stupidXmlFilter(xmlFile, taxonIds):
        pm.taskCompleted()
        if recLines is None:
            # entry was rejected by the line-based taxon filter
            continue

        entryEl = etree.fromstring("".join(recLines))

        entryTax = int(entryEl.find("organism/dbReference").attrib["id"])

        # outFhs has the single key "all" when all taxons are extracted,
        # otherwise it is keyed by the integer taxon ID
        if taxonIds==['all']:
            taxId = "all"
        elif entryTax in taxonIds:
            taxId = entryTax
        else:
            logging.debug("taxon ID %d not a target taxon" % entryTax)
            continue
        entryOf, refOf, annotOf = outFhs[taxId]

        entry = copy.copy(emptyEntry)
        entryRow, seqs = parseRecInfo(entryEl, entry, isoSeqs)

        # writeFaSeqs handles the "all" case itself
        writeFaSeqs(faFiles, entryTax, seqs)

        entryOf.write("\t".join(entryRow)+"\n")
        recName = entryRow.name

        refRows = list(parseRefInfo(entryEl, recName))
        for refRow in refRows:
            refOf.write("\t".join(refRow)+"\n")

        annotRecs = parseAnnotations(entryEl, entryRow.mainIsoAcc, disToName)
        for annotRow in annotRecs:
            logging.debug("writing row %s" % str(annotRow))
            annotOf.write("\t".join(annotRow)+"\n")

    logging.info("Skipped annotation types: %s" % ignoredTypes.most_common())

def main(args, options):
    """ entry point: check the command line and run the parser

    args    -- positional arguments: uniprotFtpDir, taxonIds, outDir
    options -- parsed options, the attributes test and trembl are read here

    taxonIds is either "all" or a comma-separated list of integers.
    outDir is created if it does not exist.
    With options.test, the doctests are run and the program exits.
    Raises Exception if the number of arguments is not three.
    """
    if options.test:
        import doctest
        doctest.testmod()
        sys.exit(0)

    # check this first: otherwise a missing argument leads to an IndexError
    # below and the output directory is created for an invalid command line
    if len(args)!=3:
        raise Exception("Invalid command line. Show help with -h")

    setupLogging("pubParseDb", options)

    db = "swissprot"
    if options.trembl:
        db = "trembl"

    dbDir = args[0]

    taxonIds = args[1]
    if taxonIds=="all":
        taxonIds = ['all']
    else:
        taxonIds=[int(x) for x in taxonIds.split(",")]

    refDir = args[2]
    if not isdir(refDir):
        logging.info("Making directory %s" % refDir)
        os.makedirs(refDir)

    parseUniprot(db, dbDir, refDir, taxonIds)

# === COMMAND LINE INTERFACE, OPTIONS AND HELP ===
# The command line is parsed at module level: parseUniprot reads the global
# variable "options" (options.parse).
parser = optparse.OptionParser("""usage: %prog [options] uniprotFtpDir taxonIds outDir - Convert UniProt to tab-sep files

taxonIds can be "all"

To download uniProt, this command is a good idea:
lftp ftp://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/ -e \
    "mirror . --use-pget-n=10 --exclude-glob *.dat.gz -P 5"

This parser:
- goes over the disease comment evidences and tries to
  classify annotations as disease-related or not.
- resolves disease codes to full disease names
- gets the PMIDs for all evidences
- only gets a limited list of xrefs, but others are easy to add
- can parse Trembl (Who came up with the idea of creating
  a 500GB XML file?)

Example:
%prog /hive/data/outside/uniProt/current 9606 tab/

If you get no results from this script, your species may be only in Trembl.
Use the '--trembl' option to parse UniProt/Trembl instead of UniProt/SwissProt.
Most organisms have entries in both databases and you have to run
the script twice to get all entries.
""")

parser.add_option("-d", "--debug", dest="debug", action="store_true", help="show debug messages")
parser.add_option("", "--test", dest="test", action="store_true", help="run tests")
parser.add_option("-p", "--parse", dest="parse", action="store", help="parse a single uniprot xml file (debugging)")
parser.add_option("", "--trembl", dest="trembl", action="store_true", help="parse trembl. Default is to parse only the swissprot files.")
(options, args) = parser.parse_args()

# without any arguments, show the help text. --test needs no arguments.
if not args and not options.test:
    parser.print_help()
    # sys.exit, not the exit() helper of the site module: the helper is meant
    # for the interactive shell and does not exist when running with python -S
    sys.exit(1)

main(args, options)
