#!/usr/bin/env python2.7

# Max Haeussler, 6/2007, 6/2014

from sys import *
import sys
from optparse import OptionParser
import logging

# === COMMAND LINE INTERFACE, OPTIONS AND HELP ===
parser = OptionParser("usage: %prog [options] inSortedBedFile outfile (stdout ok): given sorted feats., create features between them and annotated these with the neighboring bednames. Regions around chromosomes are limited to 50kbp.")

# -d: switch logging to DEBUG level
parser.add_option(
    "-d",
    action="store_true", dest="debug",
    help="show debug messages")
# -s: tab-separated file with chromosome sizes, needed for regions at chromosome ends
parser.add_option(
    "-s", "--chromSizes",
    action="store", dest="chromSizes",
    help="use this file with chrom <tab> size - lines to get chrom sizes")
# -u: strand-aware mode, only upstream regions are written
parser.add_option(
    "-u", "--upstream",
    action="store_true", dest="upstream", default=False,
    help="only report intergenic regions that are upstream (if between -/+ neighbors, two features are created)")
# -l: maximum length of upstream regions, 0 means no limit
parser.add_option(
    "-l", "--limit",
    action="store", dest="limit", type="int", default=0,
    help="if -u is set: limit the length of upstream regions to this size")
# -m: maximum length of feature names in the output
parser.add_option(
    "-m", "--maxLen",
    action="store", dest="maxLen", type="int", default=None,
    help="limit names of features to this many characters in output bed file")
# --uniqueNames: drop all but the first feature with a given name
parser.add_option(
    "--uniqueNames",
    action="store_true", dest="uniqueNames",
    help="keep only the first feature, if identical names")
# -a: also write the input features and prefix every name with its type
parser.add_option(
    "-a", "--all",
    action="store_true", dest="all",
    help="output all features (input and spacers between them) and annotate with 'ex:', 'in:', 'ig:' (exon, intron, intergenic)")

options, args = parser.parse_args()

# Module-level switches read by the functions below; the MAIN section
# overwrites them from the parsed options.
doAll = False
maxNameLen = None

# ==== FUNCTIONs =====

def truncSize(start, end, limit, onLeft):
    """ Clip the interval start-end to at most limit bases and return it as a pair.

    A limit of 0 switches clipping off. If the interval is longer than limit,
    onLeft=True keeps the leftmost limit bases, otherwise the rightmost ones.
    """
    needsClipping = limit != 0 and end - start > limit
    if not needsClipping:
        return start, end
    if onLeft:
        return (start, start + limit)
    return (end - limit, end)

def writeBeds(outf, chrom, posList, names, type):
    """ Write one four-column bed line to outf for every entry of names.

    names[i] is paired with the (start, end) tuple in posList[i]. Names
    longer than the global maxNameLen (if it is set) are shortened and end in
    "...". Zero-length regions are not written. The type argument is not used.
    """
    for idx, label in enumerate(names):
        if maxNameLen is not None and len(label) > maxNameLen:
            label = label[:maxNameLen-3] + "..."

        start, end = posList[idx]
        # only regions with a real extent are written
        if start != end:
            outf.write("%s\t%d\t%d\t%s\n" % (chrom, start, end, label))

def output(fh, left, right, upstream, limit, type, chromSizes=None):
    """ Given two flanking beds, name the region(s) between them and write them to fh.

    left / right are the neighbouring features; one of them may be None:
      - left is None:  the region runs from 0 to right.start, left name is "Gap"
      - right is None: the region runs from left.end to chromSizes[left.chrom],
        right name is "Gap" (chromSizes must then be a dict chrom -> size)

    Without upstream, a single region named "left|right" is written. If the
    global doAll is set, the name is "ig:left|right" for type "ig" and
    "in:left" for type "in"; other types produce no name.

    With upstream, only regions upstream of a neighbour are written, each
    clipped to limit bases (0 = no clipping) next to that neighbour:
      - left on "-" and right on "+": two regions, named after left and right
      - only left on "-":  one region named after left
      - only right on "+": one region named after right
      - otherwise nothing is written

    type can be "ig", "ex" or "in" = intergenic, exon, intron.
    The globals outCount (number of names generated) and dublCount (number of
    bidirectional pairs) are updated. Returns None.
    """

    global outCount
    global dublCount
    names = []

    # coordinates and names of the region between the two neighbours
    if left is None:
        chrom = right.chrom
        start = 0
        end = right.start
        leftname = "Gap"
        rightname = right.name
    elif right is None:
        chrom = left.chrom
        start = left.end
        end = chromSizes[left.chrom]
        leftname = left.name
        rightname = "Gap"
    else:
        chrom = left.chrom
        start = left.end
        end = right.start
        leftname = left.name
        rightname = right.name

    if not upstream:
        posList = [(start, end)]
        if doAll:
            if type == "ig":
                names = ["ig:"+leftname+"|"+rightname]
            elif type == "in":
                # an intron lies between two parts of the same feature
                assert(leftname==rightname)
                names = ["in:"+leftname]
        else:
            names = [leftname+"|"+rightname]
    else:
        # everything that depends on strands (-u option)
        leftIsMinus = left is not None and left.strand == "-"
        rightIsPlus = right is not None and right.strand == "+"
        if leftIsMinus and rightIsPlus:
            # bidirectional pair: the region is upstream of both neighbours
            posList = [truncSize(start, end, limit, True),
                       truncSize(start, end, limit, False)]
            names = [leftname, rightname]
            dublCount += 1
        elif leftIsMinus:
            posList = [truncSize(start, end, limit, True)]
            names = [leftname]
        elif rightIsPlus:
            posList = [truncSize(start, end, limit, False)]
            names = [rightname]
        else:
            return

    writeBeds(fh, chrom, posList, names, type)
    outCount += len(names)

def slurpdict(fname):
    """ Read a tab-delimited file into a dict: first column -> int(second column).

    Lines with fewer than two tab-separated fields are skipped. If a key
    occurs more than once, the first value is kept and a note is written to
    stderr. Raises ValueError if a second column is not an integer.
    """
    table = {}
    # the with-block closes the file even if a value fails to parse
    with open(fname, "r") as fh:
        for line in fh:
            fields = line.strip("\n").split("\t")
            if len(fields) < 2:
                continue
            key = fields[0]
            val = int(fields[1])
            if key in table:
                sys.stderr.write("info: hit key %s two times: %s -> %s\n" % (key, key, val))
            else:
                table[key] = val
    return table

# --- generic bed stuff, from bed.py -----

def coordOverlap(start1, end1, start2, end2):
    """ Return True if the ranges start1-end1 and start2-end2 overlap.

    Ranges that merely touch (end of one == start of the other) do not
    overlap; a range that lies completely inside the other one does.
    """
    # range 2 reaches over the start of range 1
    if start2 <= start1 and end2 > start1:
        return True
    # range 2 reaches over the end of range 1
    if start2 < end1 and end2 >= end1:
        return True
    # range 1 lies within range 2
    if start1 >= start2 and end1 <= end2:
        return True
    # range 2 lies within range 1
    return start2 >= start1 and end2 <= end1

def hasOverlap(f1, f2):
    """ Return True if the features f1 and f2 are on the same chrom and their coordinates overlap """
    onDifferentChroms = f1.chrom != f2.chrom
    return False if onDifferentChroms else coordOverlap(f1.start, f1.end, f2.start, f2.end)

def parseBedFilename(fname, reqSorted=False, quiet=False):
    """ Parse the bed file fname and return the result of parseBedFile().

    fname may be the special name "stdin" to read from standard input.
    Unless quiet is set, a progress message is written to stderr.
    reqSorted is handed on to parseBedFile().
    """
    if not quiet:
        sys.stderr.write("Reading %s...\n" % fname)
    if fname == "stdin":
        return parseBedFile(sys.stdin, reqSorted)
    # parseBedFile reads all lines before it returns, so the file can be
    # closed as soon as the call is done
    with open(fname, "r") as bedFh:
        return parseBedFile(bedFh, reqSorted)

class Features(list):
    """ A list of bed features with helpers for printing, sorting and writing. """

    def __repr__(self):
        """ return the repr() of all elements, one per line """
        return "\n".join(repr(feat) for feat in self)

    def __sort__(self):
        """ sort the features in place by chrom, start and end """
        self.sort(key=lambda feat: (feat.chrom, feat.start, feat.end))

    def writeToFileHandle(self, fh):
        """ write str() of every element to the open file handle fh, one per line """
        for feat in self:
            fh.write(str(feat)+"\n")

    def writeToFile(self, fname):
        """ write all elements to the file fname; the name "stdout" writes to sys.stdout """
        if fname == "stdout":
            # stdout belongs to the process and must stay open for later output
            self.writeToFileHandle(sys.stdout)
            sys.stdout.flush()
            return
        with open(fname, "w") as fh:
            self.writeToFileHandle(fh)

    def countChromosomes(self):
        """ return the number of different chrom values among the elements """
        return len(set(feat.chrom for feat in self))

class Feature:
    """ A single bed feature: chrom, start, end plus optional annotation fields. """

    def __init__(self, seqid="", start=0, end=0, name="", score=0, strand="+"):
        """ create a feature; start, end and score are converted with int() """
        self.chrom = seqid
        self.start = int(start)
        self.end = int(end)
        self.name = name
        self.score = int(score)
        self.strand = strand

    def __repr__(self):
        """ return the feature as a tab-separated bed line (without newline) """
        fields = [self.chrom, str(self.start), str(self.end)]
        attrs = self.__dict__
        if "name" in attrs:
            fields.append(self.name)
        if "score" in attrs:
            fields.append(str(self.score))
        if "strand" in attrs:
            fields.append(self.strand)
        # the bed12 columns are only written if the block fields were set
        if "blockStarts" in attrs:
            fields.append(str(self.thickStart))
            fields.append(str(self.thickEnd))
            fields.append(self.itemRgb)
            fields.append(str(self.blockCount))
            fields.append(self.blockSizes)
            fields.append(self.blockStarts)

        return "\t".join(fields)

    def includes(self, f, tol=0):
        """ return True if f is on the same chrom and lies within this feature, extended by tol on both sides """
        if self.chrom != f.chrom:
            return False
        return self.start-tol <= f.start and self.end+tol >= f.end

    def overlaps(self, f1):
        """ returns true if two Features overlap """
        return hasOverlap(self, f1)

def parseBedLine(line):
    """ Parse one tab-separated bed line and return it as a Feature.

    The first three columns (chrom, start, end) are required. Further columns
    are read in this order as far as they are present: name, score, strand,
    thickStart, thickEnd, itemRgb, blockCount, blockSizes, blockStarts.
    Columns that are absent keep whatever Feature() initialises them to.
    The three block columns are only set if all of them are present.
    Raises ValueError if a numeric column cannot be converted with int().
    """
    fields = line.split("\t")
    count = len(fields)

    f = Feature()
    f.chrom = fields[0]
    f.start = int(fields[1])
    f.end = int(fields[2])
    if count > 3:
        f.name = fields[3]
    if count > 4:
        f.score = int(fields[4])
    if count > 5:
        f.strand = fields[5]
    # every optional column is checked on its own, so that lines with 7 to 11
    # columns do not run past the end of the field list
    if count > 6:
        f.thickStart = int(fields[6])
    if count > 7:
        f.thickEnd = int(fields[7])
    if count > 8:
        f.itemRgb = fields[8]
    if count > 11:
        f.blockCount = int(fields[9])
        f.blockSizes = fields[10]
        f.blockStarts = fields[11]
    return f

def parseBedFile(lines, reqSort=False):
    """ Parse an iterable of bed lines and return a Features() object.

    Empty lines and lines starting with "track", "browser" or "#" are
    skipped. If reqSort is set, the start positions must not decrease within
    a run of lines on the same chrom; otherwise an error is written to stderr
    and the program exits with status 1.
    """
    features = Features()
    lastChrom = None
    lastStart = -1
    for l in lines:
        l = l.strip()
        if l == "" or l.startswith(("track", "browser", "#")):
            continue
        f = parseBedLine(l)
        # the start position is only comparable to the previous one on the same chrom
        if reqSort and f.chrom == lastChrom and lastStart > f.start:
            sys.stderr.write("error: bed file is not sorted, violating feature: %s\n" % str(f))
            sys.exit(1)
        lastChrom = f.chrom
        lastStart = f.start
        features.append(f)
    return features
# ----------- MAIN --------------
# both the input bed file and the output file are required
if len(args) < 2:
    parser.print_help()
    sys.exit(1)

if options.debug:
    logging.basicConfig(level=logging.DEBUG)
if options.all:
    doAll = True

# parse input arguments
infile = args[0]
outfile = args[1]
upstream = options.upstream
limit = options.limit
chromSize = None
if options.chromSizes is not None:
    chromSize = slurpdict(options.chromSizes)

maxNameLen = options.maxLen

sys.stderr.write("Reading beds ...\n")
beds = parseBedFilename(infile)
sys.stderr.write("%d features read\n" % len(beds))

if options.uniqueNames:
    # keep only the first feature of every name
    newBeds = []
    seenNames = set()
    for b in beds:
        if b.name not in seenNames:
            newBeds.append(b)
        seenNames.add(b.name)
    logging.info("Removing features with duplicate names, feature count %d -> %d" % \
        (len(beds), len(newBeds)))
    beds = newBeds

# open output file
if outfile == "stdout":
    outf = sys.stdout
else:
    outf = open(outfile, "w")

# init stats counters
sizeMissing = 0
outCount = 0
dublCount = 0
included = 0  # only reported below, nothing in this script increments it

# output flank around first feature (nothing to do for empty input)
if beds:
    output(outf, None, beds[0], upstream, limit, "ig", chromSize)

# iterate over all pairs of neighboring bed features from input file
for left, right in zip(beds, beds[1:]):

    # generate features "on the edges" and break where chrom ends or starts
    if left.chrom != right.chrom:
        if chromSize is not None:
            output(outf, left, None, upstream, limit, "ig", chromSize)
        else:
            sizeMissing += 1
        if doAll:
            # the last feature of a chrom is an input feature like any other
            outf.write("%s\t%d\t%d\tex:%s\n" % (left.chrom, left.start, left.end, left.name))
        output(outf, None, right, upstream, limit, "ig", chromSize)
        continue

    # features cannot overlap
    if left.overlaps(right):
        sys.stderr.write("error: features in bed file overlap\n")
        sys.stderr.write("check these features: %s \n<->\n %s\n" % (left, right))
        sys.exit(1)

    # sanity check
    if left.start > right.start:
        sys.stderr.write("error: bed file does not seem to be sorted\n")
        sys.stderr.write("check these features: %s \n<->\n %s\n" % (left, right))
        sys.exit(1)

    # neighbors with the same name are two parts of one feature
    if left.name == right.name:
        rangeType = "in"
    else:
        rangeType = "ig"
    output(outf, left, right, upstream, limit, rangeType, chromSize)

    if doAll:
        outf.write("%s\t%d\t%d\tex:%s\n" % (left.chrom, left.start, left.end, left.name))

# the very last feature is never a "left" in the loop above: write the region
# up to the end of its chrom and, with -a, the feature itself
if beds:
    lastBed = beds[-1]
    if chromSize is not None:
        output(outf, lastBed, None, upstream, limit, "ig", chromSize)
    else:
        sizeMissing += 1
    if doAll:
        outf.write("%s\t%d\t%d\tex:%s\n" % (lastBed.chrom, lastBed.start, lastBed.end, lastBed.name))

if outf is not sys.stdout:
    outf.close()

# write some overall statistics to stderr
sys.stderr.write("%d features written\n" % outCount)
sys.stderr.write("%d genes/exons were dropped because one includes the other\n" % included)
if chromSize is None:
    sys.stderr.write("%d features are missing because of missing chromSizes file\n" % sizeMissing)
if upstream:
    sys.stderr.write("%d features are dublicated because of bidirectional genes and upstream option\n" % dublCount)
