#!/usr/bin/env python3
# expMatrixToBarchartBed
"""
Generate a barChart bed6+5 file from a matrix, meta data, and coordinates.
"""
import os
import sys
import argparse
import tempfile

def parseArgs(args):
    """
    Parse the command line arguments.

    args - The full argument vector, program name first (the shape of sys.argv).
           When None, sys.argv itself is used.

    Returns the argparse namespace. sampleFile, matrixFile and outputFile are opened
    by argparse (outputFile for writing, so it is created/truncated here); bedFile,
    autoSql and groupOrderFile stay plain path strings.
    When nothing follows the program name the help is printed and the process exits
    with status 1.
    """
    # Honor the argument vector handed in by the caller rather than always reading sys.argv.
    argv = sys.argv if args is None else list(args)

    parser = argparse.ArgumentParser(description = __doc__)
    parser.add_argument ("sampleFile",
        help = (" Two column no header, the first column is the samples which should match "
                "the matrix, the second is the grouping (cell type, tissue, etc)"),
        type = argparse.FileType("r"))
    parser.add_argument ("matrixFile",
        help = (" The input matrix file. The samples in the first row should exactly match the ones in "
                "the sampleFile. The labels (ex ENST*****) in the first column should exactly match "
                "the ones in the bed file."),
        type = argparse.FileType("r"))
    parser.add_argument ("bedFile",
        help = (" Bed6+1 format. File that maps the column labels from the matrix to coordinates. Tab "
                "separated; chr, start coord, end coord, label, score, strand, gene name. The score "
                "column is ignored."),
        action = "store")
    parser.add_argument ("outputFile",
        help = " The output file, bed 6+5 format. See the schema in kent/src/hg/lib/barChartBed.as. ",
        type = argparse.FileType("w"))
    # Optional arguments.
    parser.add_argument ("--autoSql",
        help = " Optional autoSql description of extra fields in the input bed.",
        action = "store",
        default = None)
    parser.add_argument ("--groupOrderFile",
        help = (" Optional file to define the group order, list the groups in a single column in "
                "the order desired. The default ordering is alphabetical."),
        action = "store")
    parser.add_argument ("--useMean",
        help = " Calculate the group values using mean rather than median.",
        action = "store_true")
    parser.add_argument ("--verbose",
        help = " Show runtime messages.",
        action = "store_true")

    parser.set_defaults(verbose = False, groupOrderFile = None)

    # No arguments at all: show the usage instead of an argparse error.
    if len(argv) <= 1:
        parser.print_help()
        sys.exit(1)
    # argv[0] is the program name, argparse only wants what follows it.
    return parser.parse_args(argv[1:])

def parseExtraFields(autoSqlFile):
    """
    Obtain a list of field names to allow extra fields in the final bed.
    Enforce the bed6+1 format since extra fields must come after the
    expCount, expScores, _dataOffset and _dataLen fields this script will add.

    autoSqlFile - Path of the autoSql (.as) file to read.

    Returns the field names in file order: the 11 required barChart fields followed
    by any extra fields. Reading stops at the first line starting with ")"; if the
    file ends without one, the fields collected so far are returned.
    Raises Exception when a required field is out of order, when an extra field
    shows up before all 11 required fields, or when a field line has no name.
    """
    requiredFieldList = ["chrom", "chromStart", "chromEnd", "name", "score", "strand", "name2",
                        "expCount", "expScores", "_dataOffset", "_dataLen"]
    requiredCount = len(requiredFieldList)
    fields = []
    with open(autoSqlFile, "r") as f:
        for inputLine in f:
            line = inputLine.strip().split()
            # Blank lines hold no field; without this guard line[0] would fail on them.
            if not line:
                continue
            # The table declaration, its quoted description and the opening paren are not fields.
            if line[0].startswith(("table", "\"", "(")):
                continue
            # The closing paren ends the field list.
            if line[0].startswith(")"):
                break
            if len(line) < 2:
                raise Exception("autoSql field line has no field name: %s\n" % inputLine.strip())
            # Field lines look like '<type> <name>; "<comment>"', keep just the name.
            clean = line[1].rstrip(";")
            if clean in requiredFieldList:
                # A required field is only valid at its own position in the list.
                if len(fields) != requiredFieldList.index(clean):
                    raise Exception("autoSql fields out of order, must be in %s, extraFields order\n" % \
                        ", ".join(requiredFieldList))
            elif len(fields) < requiredCount:
                # Extra fields are only allowed once every required field was seen.
                raise Exception("first 11 barChart fields must be %s\n" % ", ".join(requiredFieldList))
            fields.append(clean)
    return fields

def median(lst):
    """
    Return the median of a list of numbers, or None when the list is empty.
    The input is not modified. For an even number of values the result is the
    mean of the two middle values, as a float.
    lst - A list of numbers
    """
    ordered = sorted(lst)
    size = len(ordered)
    if size < 1:
        return None
    # Floor division keeps the index an integer under Python 3.
    middle = size // 2
    if size % 2 == 1:
        return ordered[middle]
    return float(ordered[middle - 1] + ordered[middle]) / 2.0

def determineScore(tpmCutoffs, tpm):
    """
    Map a tpm value onto a bed score using a list of cutoff values.
    The score is 111 times the position of the first cutoff that is strictly
    greater than tpm; when no cutoff is greater the score is 999.
    tpmCutoffs - A list of cutoff values, scanned in list order
    tpm - The value to score
    """
    blockScores = (block * 111 for block, cutoff in enumerate(tpmCutoffs) if cutoff > tpm)
    return next(blockScores, 999)

def condenseMatrixIntoBedCols(matrix, groupOrder, autoSql, sampleToGroup, validTpms, bedLikeFile, useMean):
    """
    Take an expression matrix and a dictionary that maps the samples to groups.
    Go through the expression matrix and calculate the mean or median for each group,
    writing one line per matrix row to an intermediate file as they are calculated.
    Each intermediate line is tab separated: the row label, the average of the group
    values for the row, the number of groups and the group values as a comma
    separated list.

    matrix - Iterable of tab separated lines. The first line is a label followed by the
             sample names, every other line is a row label followed by one value per sample.
    groupOrder - Optional path of a file listing the groups one per line in the desired
             order. When None the groups are ordered alphabetically.
    autoSql - Optional list of field names; entries after the first 11 are appended to
             the returned header.
    sampleToGroup - A dictionary that maps string samples to string groups.
    validTpms - A list, the row average of every row whose group values sum to more
             than 0 is appended to it.
    bedLikeFile - An intermediate file, looks slightly like a bed. Only written to.
    useMean - When true use the mean of each group, otherwise the median.

    Returns the header line describing the final bed (no trailing newline), or an empty
    string when the matrix has no data rows.
    Raises KeyError when a sample is missing from sampleToGroup, when a row has more
    columns than the header or when a group listed in groupOrder has no samples.
    """
    # Read the group order a single time; it is needed for the header and for every row.
    orderedGroups = None
    if groupOrder is not None:
        with open(groupOrder, "r") as orderFile:
            orderedGroups = [orderLine.strip("\n") for orderLine in orderFile]

    # Store some information on the bed file, most important is the order
    # of the group values.
    bedInfo = ""
    needBedInfo = True
    headerSeen = False
    # Maps the 1 based matrix column to the group of the sample in that column.
    columnToGroup = dict()

    for line in matrix:
        splitLine = line.strip("\n").split("\t")

        # The first line is a label followed by a list of the sample names.
        if not headerSeen:
            headerSeen = True
            for column, sample in enumerate(splitLine[1:], 1):
                columnToGroup[column] = sampleToGroup[sample]
            continue

        # Handle the tpm rows. groupValues holds the running mean per group when
        # useMean is set, otherwise the list of values seen for the group.
        groupValues = dict()
        groupCounts = dict()
        for column, cell in enumerate(splitLine[1:], 1):
            group = columnToGroup[column]
            tpm = float(cell)
            if not useMean:
                # Median preparation
                groupValues.setdefault(group, []).append(tpm)
            elif group not in groupValues:
                # First time this group is seen.
                groupValues[group] = tpm
                groupCounts[group] = 1
            else:
                # This group has already been seen, update the running mean.
                groupCounts[group] += 1
                normal = float(groupCounts[group])
                newTpm = (tpm * (1/normal))
                oldTpm = (((normal - 1) / normal) * groupValues[group])
                groupValues[group] = newTpm + oldTpm

        # The order the group values are written in, same for the header and the row.
        if orderedGroups is not None:
            rowGroups = orderedGroups
        else:
            rowGroups = sorted(groupValues)

        # Build the header from the first data row.
        if needBedInfo:
            needBedInfo = False
            bedInfo = "#chr\tchromStart\tchromEnd\tname\tscore\tstrand\tname2\texpCount\texpScores;"
            bedInfo += " ".join(rowGroups) + "\t_offset\t_lineLength"
            if autoSql and len(autoSql) != 11: # parseExtraFields requires first 11 fields to be standard
                bedInfo += "\t" + "\t".join(autoSql[11:])

        # Create a list of the scores per group.
        # The fullAverage is used to assign a tpm score representative of the entire bed row.
        scores = []
        fullAverage = 0.0
        for group in rowGroups:
            value = groupValues[group] if useMean else median(groupValues[group])
            scores.append("%0.2g" % value)
            fullAverage += value
        groupCount = len(scores)
        rowAverage = fullAverage/groupCount

        # If the summed tpm is greater than 0 then consider the row in the validTpm list.
        if (fullAverage > 0.0):
            validTpms.append(rowAverage)
        # The row label is needed to join with coordinates later, the rest will become
        # columns of the final bed.
        bedLikeFile.write(splitLine[0] + "\t" + str(rowAverage) + "\t" + str(groupCount) + "\t" + \
                ",".join(scores) + "\n")
    # Return the bedInfo so it can be printed right before the script ends.
    return bedInfo

def _runShellCommand(cmd):
    """
    Run cmd through the shell with os.system and return the status it reports.
    A non zero status is reported on stderr, the caller carries on regardless.
    """
    status = os.system(cmd)
    if status != 0:
        sys.stderr.write("Warning: the command below exited with status " + str(status) + ":\n" + cmd + "\n")
    return status

def expMatrixToBarchartBed(options):
    """
    Convert the expression matrix into a barchart bed file.
    options - The command line options (file names, etc) as returned by parseArgs.

    Steps:
      1. Use the sample file to map the sample names to their groups.
      2. Go through the matrix line by line and get the median or mean for each group,
         written to an intermediate file (see condenseMatrixIntoBedCols).
      3. Sort the intermediate file and the coordinates file and link them with the unix
         'join' command on the matrix row label.
      4. Re arrange the joined columns into bed order, replacing the score column with a
         score derived from the row average (see determineScore). Entries whose start
         is greater than their end are dropped with a message on stderr.
      5. Run the external bedJoinTabOffset program to index the matrix, adding the
         offset and length columns.
      6. Write the header line followed by the indexed bed rows to options.outputFile;
         with --autoSql the extra fields are moved behind the two indexing columns.

    Exits with status 1 when a sample file line does not have exactly two columns or
    when no matrix row has expression above 0 (no score cutoffs can be derived).
    """
    # Only needed here, to quote the user supplied paths placed on shell command lines.
    from shlex import quote

    # Create a dictionary that maps the sample names to their group.
    sampleToGroup = dict()
    for lineNumber, item in enumerate(options.sampleFile, 1):
        splitLine = item.strip("\n").split("\t")
        if (len(splitLine) != 2):
            print ("There was an error reading the sample file at line " + str(lineNumber))
            sys.exit(1)
        # The first group listed for a sample wins if the sample is repeated.
        sampleToGroup.setdefault(splitLine[0], splitLine[1])

    autoSql = parseExtraFields(options.autoSql) if (options.autoSql) else None

    # Use an intermediate file to hold the average values for each group, format: name,mean/median,expCount,expScores
    bedLikeFile = tempfile.NamedTemporaryFile( mode = "w+")
    # Keep a list of TPM scores greater than 0. This will be used later
    # to assign bed scores.
    validTpms = []

    # Go through the matrix and condense it into a bed like file. Populate
    # the validTpms array and the bedInfo string.
    bedInfo = condenseMatrixIntoBedCols(options.matrixFile, options.groupOrderFile, autoSql, sampleToGroup, \
                    validTpms, bedLikeFile, options.useMean)
    # The shell reads this file by name below, so the buffered rows have to reach the disk first.
    bedLikeFile.flush()

    # Find the number which divides the list of non 0 TPM scores into ten blocks.
    tpmMedian = sorted(validTpms)
    if not tpmMedian:
        sys.stderr.write("No row of the matrix has expression above 0, unable to assign bed scores.\n")
        sys.exit(1)
    # Floor division, the result is used as a list index.
    blockSizes = len(tpmMedian) // 10

    # Create a list of the nine TPM values at the edges between the blocks.
    # These are used to cast a TPM score to a bed score, see determineScore.
    tpmCutoffs = [tpmMedian[blockSizes*i] for i in range(1,10)]

    # Sort the bed like file to prepare it for the join.
    sortedBedLikeFile = tempfile.NamedTemporaryFile( mode = "w+")
    _runShellCommand("sort -k1 " + bedLikeFile.name + " > " + sortedBedLikeFile.name)

    # Sort the coordinate file to prepare it for the join.
    sortedCoords = tempfile.NamedTemporaryFile( mode = "w+")
    _runShellCommand("sort -k4 " + quote(options.bedFile) + " > " + sortedCoords.name)

    # Join the bed-like file and the coordinate file, the awk accounts for any extra
    # fields that may be included, and keeps the file in standard bed 6+5 format
    joinedFile = tempfile.NamedTemporaryFile(mode="w+")
    cmd = "join -t '\t' -1 4 -2 1 " + sortedCoords.name + " " + sortedBedLikeFile.name + \
            " | awk -F'\\t' -v OFS=\"\\t\" '{printf \"%s\\t%s\\t%s\\t%s\", $2,$3,$4,$1; " + \
            "for (i=5;i<=NF;i++) {printf \"\\t%s\", $i}; printf\"\\n\";}' > " + joinedFile.name
    _runShellCommand(cmd)

    # Go through the joined file and re arrange the columns creating a bed 6+5+ file.
    # Also assign a scaled score to each tpm value.
    # Joined columns: chr, start, end, label, score, strand, name2, [extra fields],
    # row average, group count, group values.
    bedFile = tempfile.NamedTemporaryFile(mode="w+")
    for line in joinedFile:
        splitLine = line.strip("\n").split("\t")
        # Drop sequences where start is greater than end.
        if (float(splitLine[1]) > float(splitLine[2])):
            sys.stderr.write("This transcript: " + splitLine[0] + " was dropped since chr end, " + \
                    splitLine[2] + ", is smaller than chr start, " + splitLine[1] + ".\n")
            continue
        # The row average is third from the end, it only serves to derive the score.
        score = str(determineScore(tpmCutoffs, float(splitLine[-3])))
        if autoSql:
            # Replace the input score with the recalculated one, drop the row average and
            # move the extra fields behind the group count and group values.
            bedLine = "\t".join(splitLine[:4] + [score] + splitLine[5:7] + splitLine[-2:] + splitLine[7:-3]) + "\n"
        else:
            # Replace the input score with the recalculated one and drop the row average.
            bedLine = "\t".join(splitLine[:4] + [score] + splitLine[5:7] + splitLine[8:]) + "\n"

        bedFile.write(bedLine)
    # bedJoinTabOffset reads this file by name, make sure every row is on disk.
    bedFile.flush()

    # Run Max's indexing script: TODO: add verbose options
    indexedBedFile = tempfile.NamedTemporaryFile(mode="w+")
    cmd = "bedJoinTabOffset " + quote(options.matrixFile.name) + " " + bedFile.name + " " + indexedBedFile.name
    # NOTE(review): with this ordering only stdout is discarded, stderr is sent to the
    # stdout of this script - confirm that is what is wanted.
    if not options.verbose: cmd += " 2>&1 >/dev/null"
    _runShellCommand(cmd)

    finalBedName = indexedBedFile.name
    # any extra fields must come after the fields added by bedJoinTabOffset
    if autoSql:
        reorderedBedFile = tempfile.NamedTemporaryFile(mode="w+")
        # first print the standard bed6+3 barChart fields
        # then print the two fields added by bedJoinTabOffset
        # then any extra fields at the end:
        cmd = "awk -F'\\t' -v \"OFS=\\t\" '" + \
                "{for (i = 1; i < 10; i++) {if (i > 1) printf \"\\t\"; printf \"%s\", $i;}; " + \
                "for (i = NF-1; i <= NF; i++) {printf \"\\t%s\", $i;} " + \
                "for (i = 10; i < NF - 1; i++) {printf \"\\t%s\", $i;} " + \
                "printf \"\\n\";}' " + indexedBedFile.name + " > " + reorderedBedFile.name
        _runShellCommand(cmd)
        finalBedName = reorderedBedFile.name

    # Write the bed info as the first line of the output followed by the bed rows. This
    # goes through the already open output handle so the header needs no shell quoting.
    options.outputFile.write(bedInfo + "\n")
    with open(finalBedName, "r") as finalBed:
        for bedRow in finalBed:
            options.outputFile.write(bedRow)
    options.outputFile.flush()

    if options.verbose: print ("The columns and order of the groups are; \n" + bedInfo)

def main(args):
    """
    Script entry point: turn args into options with parseArgs and hand them
    to expMatrixToBarchartBed.
    args - The argument vector passed on to parseArgs
    """
    expMatrixToBarchartBed(parseArgs(args))

if __name__ == "__main__":
    # Only run the conversion when executed as a script, not when imported.
    exitStatus = main(sys.argv)
    sys.exit(exitStatus)
