#!/usr/bin/env python

# possible features:
# - add counts per diagnosis

import logging, sys, optparse, gzip
from collections import defaultdict, namedtuple, Counter
from os.path import join, basename, dirname, isfile, walk, splitext
from os import system

# MAF columns that are identical for all variants grouped at the same position.
# Only a single value (taken from the first variant of the group) is written
# to the BED.
fixedFields = [
    "Hugo_Symbol",
    "Entrez_Gene_Id",
    "Variant_Classification",
    "Variant_Type",
    "Reference_Allele",
    "Tumor_Seq_Allele1",
    "Tumor_Seq_Allele2",
    "dbSNP_RS",
    "dbSNP_Val_Status"
]

# MAF columns that vary between the variants at the same position. They are
# written to the BED as a comma-separated list, one value per variant.
varFields = [
    "Tumor_Sample_Barcode",
    "Matched_Norm_Sample_Barcode",
    "case_id"
]

# clinical annotations that are summarized as "value:count" pairs
# (currently none)
annotCountFields = [
]

# clinical annotations that are copied as a comma-separated list into every
# variant, one value per variant at that position
annotListFields = [
    # from annotations.tsv
    "days_to_death",
    # from exposure.tsv
    "cigarettes_per_day",
    "weight",
    "alcohol_history",
    "alcohol_intensity",
    "bmi",
    "years_smoked",
    "height",
    "gender",
    "project_id",
    "ethnicity",
]


# field descriptions for the AS file: field name -> description string.
# Fields without an entry here get their own name as the description.
fieldDescs = {

}

# ==== functions =====
def parseArgs():
    """Parse the command line and set up logging. -h shows the auto-generated help page.

    Prints the help page and exits with status 1 unless exactly the four
    positional arguments (inDir, annotationFile, exposureFile, outBb) are given.

    Returns:
        (args, options): the list of the four positional arguments and the
        optparse options object (only options.debug is defined).
    """
    parser = optparse.OptionParser("""usage: %prog [options] inDir annotationFile exposureFile outBb - summarize GDC MAF files to bigBed

    inDir will be recursively searched for .maf.gz files. Their variants will be extracted, written to outBb

    annotationFile and exposureFile are two .tsv files, both are distributed by the GDC.
    See /hive/data/outside/gdc/pancan33-variants/annotations/ for the versions from Jan 2019


    """)

    parser.add_option("-d", "--debug", dest="debug", action="store_true", help="show debug messages")
    (options, args) = parser.parse_args()

    # main() unpacks exactly four positional arguments, so reject any other
    # count here with the help page instead of failing later with a ValueError
    if len(args) != 4:
        parser.print_help()
        sys.exit(1)

    logging.basicConfig(level=logging.DEBUG if options.debug else logging.INFO)
    return args, options


def findMafs(inDir):
    """Recursively find .maf.gz files under inDir.

    Args:
        inDir: directory to search.

    Returns:
        List of paths (inDir joined with the relative location) of all files
        whose name ends with ".maf.gz". Directories and files are visited in
        sorted order, so the result is the same on every run.
    """
    # os.path.walk() is deprecated and does not exist in Python 3; os.walk()
    # is available in both. Imported locally because the file only imports
    # individual names from os.
    import os

    mafFnames = []
    for dirName, subDirs, fnames in os.walk(inDir):
        # sorting in place fixes the order in which os.walk descends
        subDirs.sort()
        for fname in sorted(fnames):
            if fname.endswith(".maf.gz"):
                mafFnames.append(join(dirName, fname))
    return mafFnames

def openFile(fname):
    " open fname for reading, going through gzip if the name ends with .gz, and return the file object "
    opener = gzip.open if fname.endswith(".gz") else open
    return opener(fname)

def parseTsv(fname, Rec):
    """Parse a tab-separated file and yield one Rec (~struct) per data row.

    Lines starting with "#" are skipped. A line whose columns are equal to
    Rec._fields is the header and is skipped, too.

    Args:
        fname: input file name, opened with openFile() (so .gz is supported).
        Rec: namedtuple class, as returned by makeRec().

    Yields:
        Rec instances, fields filled in column order.

    Raises:
        ValueError: if a MAF header line (starts with "Hugo_Sym") does not
            have exactly the fields of Rec.
    """
    # the with-block closes the file when the generator is exhausted or discarded
    with openFile(fname) as ifh:
        for line in ifh:
            if line.startswith("#"):
                continue
            row = line.rstrip("\n").split("\t")
            isHeader = (tuple(row) == Rec._fields)
            # a MAF header with other columns means the rows cannot be mapped to Rec
            if line.startswith("Hugo_Sym") and not isHeader:
                raise ValueError("header line of %s does not match the fields of the record type" % fname)
            # comparing against Rec._fields also catches headers of non-MAF
            # files, which would otherwise be returned as a data row
            if isHeader:
                continue
            yield Rec(*row)

def makeRec(fname):
    """Return the namedtuple class (=struct) made from the first non-comment line in fname.

    The first line that does not start with "#" is split on tabs and its
    columns become the field names.

    Returns:
        The namedtuple class, or None if the file has no non-comment line.
    """
    # use a with-block so the file is closed even though we return early
    with openFile(fname) as ifh:
        for line in ifh:
            if line.startswith("#"):
                continue
            headers = line.rstrip("\n").split("\t")
            return namedtuple("mafRecord", headers)
    return None

def joinMaxLen(vals, maxLen=255):
    """Join the strings in vals with commas and return the result.

    The result is never shortened: maxLen is accepted only so that existing
    callers keep working and is currently ignored. An empty list gives an
    empty string.
    """
    return ",".join(vals)

def _mafVariantName(varType, refAll, all1, all2):
    " return the BED item name for a variant: ref>alt for SNPs, del+ref for deletions, ins+alt for insertions "
    if varType == "SNP":
        if refAll != all1:
            raise ValueError("SNP with Reference_Allele %r different from Tumor_Seq_Allele1 %r" %
                    (refAll, all1))
        return refAll + ">" + all2
    if varType == "DEL":
        return "del" + refAll
    if varType == "INS":
        return "ins" + all2
    raise ValueError("unsupported Variant_Type %r" % varType)


def mafsToBed(mafFnames, annots, outBed):
    """Summarize the variants of all mafFnames into the file outBed.

    All MAF rows are read into memory and grouped by (chromosome, start, end,
    Tumor_Seq_Allele1, Tumor_Seq_Allele2). One BED line is written per group,
    in no particular order: 12 BED columns with the number of rows in the
    group (capped at 1000) as score, then the raw count, the count divided by
    the number of distinct Tumor_Sample_Barcode values over all files, the
    fixedFields of the first row of the group, one summary per
    annotCountFields entry, and one comma-separated list per annotListFields
    and varFields entry.

    Args:
        mafFnames: list of MAF file names; the record layout is taken from the
            first file that has a header.
        annots: dict case_id -> dict field name -> value with the clinical
            annotations, looked up with the case_id of every MAF row.
        outBed: name of the output file, overwritten if it exists.

    Raises:
        ValueError: for a Variant_Type other than SNP, DEL or INS, or for a
            SNP whose Reference_Allele differs from Tumor_Seq_Allele1.
        KeyError: if a case_id or an annotation field is missing from annots.
    """
    byPosAndAllele = defaultdict(list)
    MafRec = None
    allSampleIds = set()
    # read the whole data into memory, index by chrom pos
    for mafName in mafFnames:
        if MafRec is None:
            MafRec = makeRec(mafName)
        logging.info("Reading %s" % mafName)
        lineCount = 0
        for var in parseTsv(mafName, MafRec):
            key = (var.Chromosome, var.Start_Position, var.End_Position,
                    var.Tumor_Seq_Allele1, var.Tumor_Seq_Allele2)
            byPosAndAllele[key].append(var)
            lineCount += 1
            allSampleIds.add(var.Tumor_Sample_Barcode)
        logging.info("Read %d rows" % lineCount)

    sampleCount = len(allSampleIds)
    logging.info("Read %d variants from %d samples" % (len(byPosAndAllele),sampleCount))

    logging.info("Writing variant summary to %s" % outBed)
    # now summarize all variants for every chrom pos.
    # The with-block closes the output file even when a variant is rejected.
    with open(outBed, "w") as ofh:
        for key in byPosAndAllele:
            chrom, start, end, all1, all2 = key
            varList = byPosAndAllele[key]
            var1 = varList[0]._asdict()

            name = _mafVariantName(var1["Variant_Type"], var1["Reference_Allele"], all1, all2)

            varSampleCount = len(varList) # score is number of variants with this nucl change
            score = min(1000, varSampleCount) # don't exceed 1000

            # MAF start positions are converted to BED by subtracting 1, the end is kept
            chromStart = str(int(start) - 1)
            ftLen = str(int(end) - int(start) + 1)
            row = [chrom, chromStart, end, name, str(score), ".", chromStart, end,
                    "0,0,0", "1", ftLen, "0"]

            # add the count/frequency fields that we derive from the data
            row.append(str(varSampleCount)) # raw count
            row.append(str(float(varSampleCount) / sampleCount)) # frequency

            # a few fields are always the same, so we just add the value from the first variant
            for fieldName in fixedFields:
                row.append(var1[fieldName])

            # some annotations are summarized as counts, most common value first
            for summCountField in annotCountFields:
                valCounts = Counter(annots[v.case_id][summCountField] for v in varList)
                row.append(", ".join("%s:%d" % (val, count) for val, count in valCounts.most_common()))

            # some annotations are copied as a long list
            for fieldName in annotListFields:
                row.append(joinMaxLen([annots[v.case_id][fieldName] for v in varList]))

            # some variant values are copied as a long list
            for fieldName in varFields:
                row.append(joinMaxLen([getattr(v, fieldName) for v in varList]))

            ofh.write("\t".join(row))
            ofh.write("\n")

    logging.info("Created file %s" % outBed)

def makeAs(extraFields, outFname):
    " create the autoSql (.as) file outFname: the bed12 columns, sampleCount and freq, then one lstring column for every name in the lists of extraFields "
    bed12Part = """table bed12Mafs
"somatic variants converted from MAF files obtained through the NCI GDC"
    (
    string chrom;      "Chromosome (or contig, scaffold, etc.)"
    uint   chromStart; "Start position in chromosome"
    uint   chromEnd;   "End position in chromosome"
    string name;       "Name of item"
    uint   score;      "Score from 0-1000"
    char[1] strand;    "+ or -"
    uint thickStart;   "Start of where display should be thick (start codon)"
    uint thickEnd;     "End of where display should be thick (stop codon)"
    uint reserved;     "Used as itemRgb as of 2004-11-22"
    int blockCount;    "Number of blocks"
    int[blockCount] blockSizes; "Comma separated list of block sizes"
    int[blockCount] chromStarts; "Start positions relative to chromStart"
    string sampleCount;    "number of samples with this variant"
    string freq;    "variant frequency"
    """
    asFh = open(outFname, "w")
    asFh.write(bed12Part)

    # every extra column is an lstring; its description comes from fieldDescs
    # and falls back to the column name itself
    asFh.writelines(
        '    %s %s; "%s"\n' % ("lstring", colName, fieldDescs.get(colName, colName))
        for colNames in extraFields
        for colName in colNames
    )

    asFh.write(")\n")
    asFh.close()
    logging.info("Wrote %s" % asFh.name)

def bedToBb(outBed, outBedSorted, chromSizes, outAs, outBb):
    """Sort outBed into outBedSorted and convert it to the bigBed file outBb.

    Runs the external programs "sort" and "bedToBigBed"; both must be in the
    PATH.

    Args:
        outBed: unsorted input BED file.
        outBedSorted: file name for the sorted copy of outBed.
        chromSizes: chrom.sizes file passed to bedToBigBed.
        outAs: autoSql (.as) file describing the BED columns.
        outBb: name of the bigBed file to create.

    Raises:
        subprocess.CalledProcessError: if one of the programs exits with a
            non-zero status.
        OSError: if one of the programs cannot be started.
    """
    # local imports: the file only imports "system" from os
    import os
    import subprocess

    logging.info("Converting to bigBed")

    # The commands are run with check_call instead of inside an assert:
    # asserts are removed by "python -O", which would skip the commands.
    # Argument lists (no shell) keep file names with spaces intact.
    # LC_ALL=C makes sort compare chromosome names byte by byte, independent
    # of the user's locale.
    sortEnv = dict(os.environ, LC_ALL="C")
    subprocess.check_call(["sort", "-k1,1", "-k2,2n", "-o", outBedSorted, outBed], env=sortEnv)

    subprocess.check_call(["bedToBigBed", outBedSorted, chromSizes,
        "-type=bed12+", "-as=%s" % outAs, "-tab", outBb])
    logging.info("Created %s" % outBb)

def parseAnnots(annotFn):
    """Read clinical annotations and return them as a dict caseId -> namedtuple.

    The record layout comes from the first non-comment line of annotFn, the
    key is the case_id column.

    Raises:
        ValueError: if a case_id occurs on more than one row.
    """
    Rec = makeRec(annotFn)
    ret = {}
    for annot in parseTsv(annotFn, Rec):
        # guard against the header line being handed back as a data row,
        # it would end up as a bogus entry with the key "case_id"
        if tuple(annot) == Rec._fields:
            continue
        caseId = annot.case_id
        # explicit check instead of assert, as asserts are removed by "python -O"
        if caseId in ret:
            raise ValueError("duplicate case_id %r in %s" % (caseId, annotFn))
        ret[caseId] = annot
    return ret

# ----------- main --------------
def main():
    " run the whole conversion: parse both annotation files, summarize the MAFs to BED, write the .as file, build the bigBed "
    args, options = parseArgs()
    inDir, annotFn, expFn, outBb = args

    clinical = parseAnnots(annotFn)
    exposures = parseAnnots(expFn)

    # merge both tables into one dict case_id -> field_name -> value,
    # exposure values are applied on top of the annotation values
    joinedAnnots = {}
    for caseId in clinical:
        merged = clinical[caseId]._asdict()
        merged.update(exposures[caseId]._asdict())
        joinedAnnots[caseId] = merged

    mafFnames = findMafs(inDir)

    # all intermediate files are named after the output file, minus its extension
    outBase = splitext(outBb)[0]
    outBed, outBedSorted, outAs = outBase+".bed", outBase+".s.bed", outBase+".as"

    mafsToBed(mafFnames, joinedAnnots, outBed)

    asFieldLists = [fixedFields, annotCountFields, annotListFields, varFields]
    makeAs(asFieldLists, outAs)

    chromSizes = "/hive/data/genomes/hg38/chrom.sizes"
    bedToBb(outBed, outBedSorted, chromSizes, outAs, outBb)

# run the conversion only when executed as a script, not when imported
if __name__ == "__main__":
    main()
