#!/usr/bin/env python2.7

# a converter for kallisto output files
# parses ~800 cells in 1-2 minutes
# and merges everything into a single matrix that can be read with one line in R
# also outputs binary hash files (marshal format) that can be read from python with one line

import logging, sys, optparse, gzip, gc
from collections import defaultdict
from os.path import join, basename, dirname, isfile, isdir, splitext
import os, marshal, glob
from multiprocessing.dummy import Pool
import time

# switch off Python's automatic garbage collector for the whole run
# (original rationale: no garbage collection needed in here)
gc.disable()

# the column of the kallisto output file with the TPM value:
# 0-based index into the tab-separated fields of a line, read by parseKallisto()
tpmColumnIndex = 4

# keeps genes with only 0 values, uses more memory but then the output line count is always the same.
# When False, parseKallisto() skips rows with a TPM of 0.
# main() sets this to True if --addZeros is given.
addZeros = False

# skip format checks for gene and transcript IDs:
# when True, parseGenes() does not require a "T" in transcript IDs and a "G" in gene IDs.
# main() sets this to True if --notEnsembl is given.
notEnsembl = False

# === command line interface, options and help ===
def parseArgs():
    """ parse the command line, set up logging and return (args, options)

    args is the list of positional arguments, options the optparse values
    object. main() reads args[0] (dirName) and args[1] (outBase), so the help
    text is printed and the program exits with status 1 unless at least two
    positional arguments are given. Additional positional arguments are
    tolerated and ignored, as before.
    """
    parser = optparse.OptionParser("""usage: %prog [options] dirName outBase
    collect expression counts from all abundance.tsv files,
    given a directory containing kallisto output directories or kallisto output files.

    Input files that end with "*.abundance.tsv" can be located either 
    directly in dirName or one level beneath it.

    If the input is directories, the directory name is the cell ID.
    If the input is files, the part before the first dot is the cell ID.

    TPM transcript counts are written to <outBase>.tpm.tab
    TPM gene counts are written to <outBase>.geneTpm.tab

    Optional: Parses dirName/log/*.log files with kallisto output and creates a .tab file
    that contains two basic stats: number of reads per file and number of
    reads mapped to transcript models, written to <outBase>.stat.tab
    Can be used for tagStormJoinTab.
    It is assumed that the part before the first dot in the log filename
    is the identifier of the cell. If available, the part before the
    underscore in this filename is the name of the sample ("meta").

    Ctrl-c does not work for this program. Do ctrl-z and then 'kill %%'.

    Example:
    kallistoToMatrix /hive/groups/cirm/submit/quake/quakeBrainGeo1/kallistoOut quakeBrainGeo1
    """)

    parser.add_option("-d", "--debug", dest="debug", action="store_true", help="show debug messages")
    parser.add_option("-t", "--threads", dest="threads", action="store", type="int", help="number of threads to use, default %default", default=30)
    parser.add_option("-s", "--addZeros", dest="addZeros", action="store_true", help="keep genes with only 0 values, makes sure that the matrix rowcount is always the same")
    parser.add_option("-r", "--recursive", dest="recursive", action="store_true", help="search in all subdirectories for files that end with *abundance.tsv")
    parser.add_option("",  "--notEnsembl", dest="notEnsembl", action="store_true", help="Do not check transcript or gene identifiers for Ensembl format.")
    parser.add_option("", "--sym", dest="addSymbols", action="store_true", help="append symbols after gene ID, separate with '/'")
    parser.add_option("", "--transFile", dest="transFile", action="store", help="the text file with the mapping transcript-gene-symbol (sym is not used), one per line. default %default", default="/hive/data/outside/gencode/release_22/transToGene.tab")
    (options, args) = parser.parse_args()

    # both dirName and outBase are required; a lone dirName must not get past this point
    if len(args) < 2:
        parser.print_help()
        sys.exit(1)

    if options.debug:
        logLevel = logging.DEBUG
    else:
        logLevel = logging.INFO
    logging.basicConfig(level=logLevel)

    return args, options

# ==== functions =====
    
def findInputFiles(dirname, doRecurse):
    """ return two parallel lists: cellIds and paths of the kallisto abundance files

    Two layouts are recognized:
    - <cellId>/abundance.tsv : the name of the containing directory is the cellId
    - <cellId>.<anything>.abundance.tsv or <cellId>.abundance.tsv : the part of
      the file name before the first dot is the cellId

    If doRecurse is false, only entries directly in dirname are considered
    (files in dirname, or abundance.tsv one level beneath it).
    If doRecurse is true, the whole tree under dirname is walked and both
    layouts are accepted at any depth.

    Paths are processed in sorted order, so the order of the returned lists
    (and of the matrix columns built from them) does not depend on the file system.
    """
    logging.info("Looking for *.abundance.tsv files in %s" % dirname)

    if doRecurse:
        logging.info("Looking in all subdirectories")
        inPaths = []
        for dirpath, subdirs, files in os.walk(dirname):
            for fname in files:
                if fname == "abundance.tsv" or fname.endswith(".abundance.tsv"):
                    inPaths.append(join(dirpath, fname))
    else:
        inPaths = glob.glob(join(dirname, "*"))
    inPaths.sort()

    cellIds = []
    filePaths = []
    for inputPath in inPaths:
        fname = basename(inputPath)
        if isfile(inputPath):
            if fname.endswith(".abundance.tsv"):
                cellId = fname.split(".")[0]
            elif doRecurse and fname == "abundance.tsv":
                # plain kallisto output found by the directory walk: the cell is
                # named after the directory that holds the file.
                # The parameter 'dirname' shadows os.path.dirname, so use the full name.
                cellId = basename(os.path.dirname(os.path.abspath(inputPath)))
            else:
                continue
            filePath = inputPath
        elif isdir(inputPath):
            # check if it's a dir with the right filename in it
            filePath = join(inputPath, "abundance.tsv")
            if not isfile(filePath):
                continue
            cellId = fname
        else:
            continue
        cellIds.append(cellId)
        filePaths.append(filePath)

    if len(set(cellIds)) != len(cellIds):
        logging.warning("Some cell IDs occur more than once among the input files")

    logging.info("Found %d input files" % len(filePaths))
    return cellIds, filePaths

def parseKallisto(fname):
    """ parse a kallisto abundance.tsv file, return dict transcriptId -> est_tpm

    The first line of the file is treated as a header and skipped.
    The TPM value is taken from column tpmColumnIndex and converted to float.
    Unless the global addZeros is set, transcripts with a TPM of 0 are left out
    of the result, however the zero is spelled in the file ("0", "0.0", ...).
    Empty lines are ignored.
    """
    logging.debug("parsing %s" % fname)

    d = {}
    with open(fname) as ifh:
        ifh.readline() # skip the header line
        for line in ifh:
            line = line.rstrip("\r\n")
            if not line:
                continue
            fs = line.split("\t")
            tpmStr = fs[tpmColumnIndex]
            # cheap string test first: most zero values are written as a plain "0"
            if tpmStr == "0" and not addZeros:
                continue
            tpm = float(tpmStr)
            if tpm == 0.0 and not addZeros:
                continue
            d[fs[0]] = tpm
    return d
        
def parseGenes(fname, doAddSyms, stripDot=False):
    """ parse the transcript-gene-symbol file and return a dict transcriptId -> geneId

    fname can be a plain text file or, if the name ends with .gz, a gzip file.
    Each data line has at least three whitespace-separated fields:
    transcriptId, geneId and symbol. Lines that start with "#" or
    "transcriptId" (header) and empty lines are skipped.

    if doAddSyms is true, the values of the dict are geneId + "/" + symbol,
    otherwise just the geneId.
    if stripDot is true, everything after the first dot (the version) is
    removed from transcript and gene IDs.

    Unless the global notEnsembl is set, transcript IDs must contain a "T" and
    gene IDs a "G". Raises ValueError on lines that fail these checks, that
    have fewer than three fields or, with doAddSyms, have a "/" in the symbol.
    """
    if fname.endswith(".gz"):
        opener = gzip.open
    else:
        opener = open

    ret = {}
    with opener(fname) as ifh:
        for line in ifh:
            if line.startswith("#") or line.startswith("transcriptId"):
                continue
            fs = line.rstrip("\n").split()
            if len(fs) == 0:
                continue
            if len(fs) < 3:
                raise ValueError("%s: expected transcriptId, geneId and symbol, got: %r" % (fname, line))
            transId, geneId, sym = fs[0], fs[1], fs[2]

            if not notEnsembl:
                # Ensembl transcript identifiers always have a T in them
                if "T" not in transId:
                    raise ValueError("%s: %r does not look like an Ensembl transcript ID" % (fname, transId))
                # Ensembl gene identifiers always have a G in them
                if "G" not in geneId:
                    raise ValueError("%s: %r does not look like an Ensembl gene ID" % (fname, geneId))

            if stripDot:
                transId = transId.split('.')[0]
                geneId = geneId.split('.')[0]

            if doAddSyms:
                # "/" is the separator between gene ID and symbol, so it cannot be part of the symbol
                if "/" in sym:
                    raise ValueError("%s: symbol %r contains a '/'" % (fname, sym))
                ret[transId] = geneId+"/"+sym
            else:
                ret[transId] = geneId
    return ret
        
def outputBigMatrix(cellNames, results, outFname, isGene=False):
    """
    given a list of cellNames and a parallel list of dictionaries ID -> count
    (transcript IDs, or gene IDs if isGene is true), write two files:

    - outFname: a tab-separated matrix with one row per ID (sorted) and one
      column per cell. The header line starts with "#gene" or "#transcript".
      IDs missing from a cell's dictionary are written as 0.
    - outFname + ".marshal": the same data as a nested dict
      cellName -> ID -> count, can be read from python with a single line:
      matrix = marshal.load(open("data.tab.marshal", "rb"))

    Both files are closed before the function returns, also if writing fails.
    """
    logging.info("Writing data to file %s" % outFname)
    if isGene:
        rowLabel = "#gene"
    else:
        rowLabel = "#transcript"

    # create a sorted list of all transcript or gene IDs
    logging.info("Getting transcript IDs")
    allIds = set()
    for res in results:
        allIds.update(res)
    allIds = sorted(allIds)

    # write out matrix
    logging.info("Iterating over transcript IDs and writing to tab file")
    with open(outFname, "w") as ofh:
        ofh.write("%s\t%s\n" % (rowLabel, "\t".join(cellNames)))
        for rowId in allIds:
            row = [str(countDict.get(rowId, 0)) for countDict in results]
            ofh.write("%s\t%s\n" % (rowId, "\t".join(row)))

    # also output as a binary file for now
    # it's a lot easier and faster to parse, at least for python scripts
    binPath = outFname+".marshal"
    logging.info("Writing %s" % binPath)
    allData = dict(zip(cellNames, results))
    with open(binPath, "wb") as binFh:
        marshal.dump(allData, binFh)
        
# this was an attempt to write sparse matrices.
# it results in smaller data files, but they are harder to compress with gzip.
# overall, the gzipped non-sparse matrices were smaller on the quake dataset

#def outputBigMatrixSparse(names, results, outFname):
    #logging.info("Writing data to file %s" % outFname)
    #ofh = open(outFname, "w")
    
    # create set of all transcript names
    #allTrans = set()
    #for res in results:
        #allTrans.update(res)

    # create dict with transcript -> number
    # and write out this mapping
    #ofh.write("#")
    #transToId = {}
    #for i, trans in enumerate(allTrans):
        #transToId[trans] = i
        #ofh.write("#%s=%d\n" % (trans, i))

    # write out matrix
    #for cellId, transDict in zip(names, results):
        #ofh.write("%s\t" % cellId)
        #strList = ["%d=%f" % (transToId[trans], val) for trans, val in transDict.iteritems()]
        #ofh.write(",".join(strList))
        #ofh.write("\n")
    #ofh.close()
    
def sumTransToGene(transDictList, transFile, doAddSyms):
    """ given a list of dict transcript -> tpm, and a file with the mapping
    transcript -> gene, map all transcripts to genes and return a list of
    dict gene -> sum of tpms, one per input dict and in the same order.

    The transcript-gene mapping is read with parseGenes(transFile, doAddSyms, stripDot=True).
    Transcript IDs are compared without their version, the part after the first dot.
    If we have no gene ID, drop the transcript entirely.
    """
    transToGene = parseGenes(transFile, doAddSyms, stripDot=True)
    logging.info("Mapping %d transcript IDs to gene IDs" % len(transToGene))

    newRes = []
    noMapTransIds = set()
    for transCounts in transDictList:
        geneCounts = defaultdict(float)
        # iterate over the keys: works without dict.iteritems(), which only exists in python2
        for versionedId in transCounts:
            transId = versionedId.split(".")[0]
            geneId = transToGene.get(transId)
            if geneId is None:
                noMapTransIds.add(transId)
            else:
                geneCounts[geneId] += transCounts[versionedId]
        newRes.append(dict(geneCounts))

    logging.info("no gene ID found for %d transcript IDs. These are probably scaffolds/patches/etc. Example IDs: %s" %
        (len(noMapTransIds), list(noMapTransIds)[:5]))
    return newRes
    
def writeStats(inDir, outFname):
    """
    search for all .log files in inDir/log. Use the basename of these files
    as the cell ID and write a .tab file that can be joined with tagStormJoinTab

    The part of the log file name before the first dot is the cellId, the part
    of the cellId before the first underscore is written as "meta".
    Output columns: cellId, meta, procReads, alnReads, estFragLen.

    The values are taken from the kallisto "[quant] processed ..." and
    "[quant] estimated average fragment length: ..." lines. A value that is
    not found in a log file is written as "NA"; values never carry over from
    one log file to the next.

    If inDir/log does not exist, nothing is written.
    """
    logDir = join(inDir, "log")
    if not isdir(logDir):
        logging.info("Cannot find %s, not parsing kallisto log files" % logDir)
        return

    inFnames = sorted(glob.glob(join(logDir, "*.log")))
    print("Parsing %d logfiles and writing to %s" % (len(inFnames), outFname))

    with open(outFname, "w") as ofh:
        ofh.write("cellId\tmeta\tprocReads\talnReads\testFragLen\n")

        for inFname in inFnames:
            cellId = basename(inFname).split(".")[0]
            meta = cellId.split("_")[0]

            # start from scratch for every file, so an incomplete log can neither
            # crash on an undefined name nor report the numbers of the previous cell
            readCount = alignCount = fragLen = "NA"

            # [quant] processed 1,836,518 reads, 636,766 reads pseudoaligned
            # [quant] estimated average fragment length: 251.99
            with open(inFname) as logFh:
                for line in logFh:
                    if line.startswith("[quant] processed "):
                        words = line.split()
                        readCount = words[2].replace(",","")
                        alignCount = words[4].replace(",","")
                    if line.startswith("[quant] estimated average fragment length:"):
                        fragLen = line.split()[5]

            if "NA" in (readCount, alignCount, fragLen):
                logging.warning("%s: kallisto stats are incomplete, writing NA" % inFname)

            row = [cellId, meta, readCount, alignCount, fragLen]
            ofh.write("\t".join(row)+"\n")

def main():
    """ run the whole conversion: parse the command line, write the optional
    stats file, read all kallisto files in parallel and write the transcript
    matrix, the gene matrix and a small log file with the parameters.

    Outputs are <outBase>.stat.tab (only if dirName/log exists),
    <outBase>.tpm.tab, <outBase>.geneTpm.tab, their .marshal companions
    and <outBase>.log.
    Exits with status 1 if no input files are found.
    """
    global addZeros, notEnsembl

    args, options = parseArgs()
    inDir = args[0]
    outBase = args[1]

    # copy the options into the module-level switches read by the parsers
    if options.addZeros:
        addZeros = True
    if options.notEnsembl:
        notEnsembl = True

    transOutFname = outBase+".tpm.tab"
    statOutFname = outBase+".stat.tab"
    geneOutFname = outBase+".geneTpm.tab"
    logFname = outBase+".log"

    writeStats(inDir, statOutFname)

    cellNames, inFnames = findInputFiles(inDir, options.recursive)
    assert(len(cellNames)==len(inFnames))

    if len(inFnames)==0:
        logging.info("No input files found")
        sys.exit(1)
    logging.info("Found %d kallisto input files (*.abundance.tsv)" % len(inFnames))

    threadCount = options.threads
    logging.info("Using %d threads to parse input files" % threadCount)

    # thread pool with as many workers as requested with --threads
    pool = Pool(threadCount)
    # chunksize=1 lowers performance but means that the progress status is accurate
    asyncResults = pool.map_async(parseKallisto, inFnames, chunksize=1)

    # poll once per second and show the progress until all files are parsed
    while not asyncResults.ready():
        print("jobs left: {}".format(asyncResults._number_left))
        time.sleep(1)

    pool.close()
    pool.join()
    results = asyncResults.get()

    logging.info("Done reading input files.")

    outputBigMatrix(cellNames, results, transOutFname)

    results = sumTransToGene(results, options.transFile, options.addSymbols)

    outputBigMatrix(cellNames, results, geneOutFname, isGene=True)

    with open(logFname, "w") as logFh:
        logFh.write("inputDir\t%s\n" % inDir)
        logFh.write("geneModelFname\t%s\n" % options.transFile)
        logFh.write("addZeros\t%s\n" % options.addZeros)
    logging.info("Wrote parameters to %s" % logFname)

# run the conversion only when executed as a script, so that the module can be
# imported (e.g. to reuse the parsers) without starting the whole pipeline
if __name__ == "__main__":
    main()
