#!/usr/bin/env python

import logging, sys, optparse
from collections import defaultdict
from os.path import join, basename, dirname, isfile

# ==== functions =====
    
def parseArgs():
    """ setup logging, parse command line arguments and options. -h shows auto-generated help page

    Returns a tuple (args, options): args is the list of positional arguments
    (PhenGenFile, snpBed) and options is the optparse values object (options.debug).
    Prints the help page and exits with status 1 if fewer than two positional
    arguments were given.
    """
    parser = optparse.OptionParser("usage: %prog [options] PhenGenFile snpBed - rewrite PhenGen file to a bed8+")

    parser.add_option("-d", "--debug", dest="debug", action="store_true", help="show debug messages")
    #parser.add_option("-f", "--file", dest="file", action="store", help="run on file") 
    #parser.add_option("", "--test", dest="test", action="store_true", help="do something") 
    (options, args) = parser.parse_args()

    # both input files are mandatory: show the help page now instead of failing
    # later with an IndexError when the second argument is accessed
    if len(args) < 2:
        parser.print_help()
        # sys.exit rather than the exit() helper, which is only installed by the site module
        sys.exit(1)

    if options.debug:
        logging.basicConfig(level=logging.DEBUG)
    else:
        logging.basicConfig(level=logging.INFO)
    return args, options
# ----------- main --------------
# number of tab-separated columns in a PhenGen line:
# #	Trait	SNP rs	Context	Gene	Gene ID	Gene 2	Gene ID 2	Chromosome	Location	P-Value	Source	PubMed	Analysis ID	Study ID	Study Name
_PHENGEN_FIELD_COUNT = 16
# 0-based column indexes into a PhenGen row
_PHENGEN_SNP_IDX = 2
_PHENGEN_SOURCE_IDX = 11
_PHENGEN_STUDY_IDX = 14

def _loadPhenGen(filename):
    """ parse the tab-separated PhenGen file, return a dict snpId -> list of rows.

    Each row is a list of 16 strings. Lines starting with '#' are skipped.
    The SNP column gets an "rs" prefix and the study ID column is rewritten to
    "phsNNNNNN", unless they are empty.
    Raises ValueError if a line does not have exactly 16 fields.
    """
    bySnp = defaultdict(list)
    with open(filename) as ifh:
        for line in ifh:
            if line.startswith("#"):
                continue
            row = line.rstrip("\n").split("\t")
            if len(row) != _PHENGEN_FIELD_COUNT:
                raise ValueError("expected %d fields but got %d in line %r" % \
                    (_PHENGEN_FIELD_COUNT, len(row), line))

            # need phsxxxxxx for linkout to NCBI
            if row[_PHENGEN_STUDY_IDX]!="":
                row[_PHENGEN_STUDY_IDX] = "phs%06d" % int(row[_PHENGEN_STUDY_IDX])
            # the file has only the number of the SNP, the bed file uses the full rs-identifier
            if row[_PHENGEN_SNP_IDX]!="":
                row[_PHENGEN_SNP_IDX] = "rs"+row[_PHENGEN_SNP_IDX]

            bySnp[row[_PHENGEN_SNP_IDX]].append(row)
    return dict(bySnp)

def _makeBedRow(chrom, start, end, snpId, rows):
    """ return the output fields (list of strings) for one SNP and its PhenGen rows.

    Fields: chrom, start, end, shortened first trait, number of PhenGen rows,
    ".", start, end, "0", then one field per PhenGen column (all except the first),
    with the values of all rows joined by ", ", and finally the SNP id if
    any row has the source NHGRI, otherwise an empty string.
    """
    bed = [chrom, start, end, snpId, str(len(rows)), ".", start, end, "0"]
    traitIdx = len(bed) # the trait is the first of the PhenGen columns

    # commas within a value are replaced with ';' as the comma separates the values of the rows
    for colIdx in range(1, _PHENGEN_FIELD_COUNT):
        values = [row[colIdx].replace(",", ";") for row in rows]
        bed.append(", ".join(values))
    # the SNP ids got their "rs" prefix when the PhenGen file was parsed,
    # so the prefix is not added a second time here

    # use trait as name, shorten
    traitShort = bed[traitIdx].split(",")[0].split(';')[0].replace(" Disorders", "").replace(" Tests", "")
    if len(traitShort) > 17:
        traitShort = traitShort[:16]+".."
    bed[3] = traitShort

    hasGwasCatalog = any(row[_PHENGEN_SOURCE_IDX]=="NHGRI" for row in rows)
    if hasGwasCatalog:
        bed.append(snpId)
    else:
        bed.append("")
    return bed

def main():
    """ read the PhenGen file and the SNP bed file named on the command line and print
    one extended bed line to stdout for every SNP of the bed file that has PhenGen rows.

    The bed file must have exactly four tab-separated fields per line
    (chrom, start, end, name), otherwise a ValueError is raised.
    """
    args, options = parseArgs()

    filename = args[0]
    snpBed = args[1]

    bySnp = _loadPhenGen(filename)

    with open(snpBed) as ifh:
        for line in ifh:
            chrom, start, end, name = line.rstrip("\n").split("\t") # has chrom, start, end, name
            rows = bySnp.get(name)
            if rows is None:
                continue
            # print with a single parenthesized argument works under Python 2 and 3
            print("\t".join(_makeBedRow(chrom, start, end, name, rows)))

# run only when executed as a script, so that importing this module has no side effects
if __name__ == "__main__":
    main()
