#!/usr/bin/env python

import logging, sys, optparse, gzip, subprocess
from collections import defaultdict
from os.path import join, basename, dirname, isfile

# ==== functions =====

# Display attributes for each NMD-escape rule, kept in one table so a rule's
# color and its mouseover text cannot drift apart.
# Each entry: (rule number, itemRgb color, HTML mouseover description).
_RULE_STYLES = (
    # red - CDS within 50bp of the last splice junction
    (1, "255,0,0", "<b>Rule 1</b> - CDS within 50bp of the last splice junction"),
    # orange - single coding exon and no 3'UTR intron
    (2, "255,140,0", "<b>Rule 2</b> - single coding exon, no 3'UTR intron"),
    # dark red - first 100bp of coding nucleotides
    (3, "139,0,0", "<b>Rule 3</b> - first 100bp of coding nucleotides"),
    # gold - long coding exon (>400 nt)
    (4, "255,215,0", "<b>Rule 4</b> - PTC in long exon (>400 nt)"),
)

# rule number -> "R,G,B" color string
RULE_COLORS = {rule: color for rule, color, _desc in _RULE_STYLES}

# rule number -> HTML mouseover text
RULE_DESCRIPTIONS = {rule: desc for rule, _color, desc in _RULE_STYLES}

# Lindeboom et al. 2016: reduced NMD efficiency for PTCs in exons longer than this
LONG_EXON_THRESHOLD = 400

def parseArgs():
    """Set up logging and parse command line arguments and options.

    Returns (args, options). args is the list of the three positional file
    names: inFname, outDecoFname, outCollapsedFname.

    If the number of positional arguments is wrong, the help page is printed
    and the process exits with status 1. optparse itself exits with status 2
    on an invalid option value, e.g. an unknown --format. -h shows the
    auto-generated help page.
    """
    parser = optparse.OptionParser("usage: %prog [options] inFname outDecoFname outCollapsedFname - "
        "Output BEDs with NMD escape regions. First output is decorator format, "
        "second is collapsed BED9+3 (mouseover, transcript list, NCBI IDs) "
        "named by gene symbol.")

    parser.add_option("-d", "--debug", dest="debug", action="store_true",
        help="show debug messages")
    # choices makes optparse reject typos up front instead of failing mid-parse
    parser.add_option("-f", "--format", dest="format", action="store", default="genePredExt",
        choices=["genePredExt", "bigGenePred"],
        help="Input format: 'genePredExt' (with bin column, e.g. ncbiRefSeq) or "
             "'bigGenePred' (e.g. gencode .bb file, will use bigBedToBed). Default: genePredExt")
    parser.add_option("--gene-sym-field", dest="geneSymField", type="int", default=17,
        help="bigGenePred 0-based field index that holds the gene symbol. "
             "Default 17 (Gencode 'geneName'). MANE puts the HGNC symbol in field 18.")
    parser.add_option("--ncbi-id-field", dest="ncbiIdField", type="int", default=-1,
        help="bigGenePred 0-based field index that holds the NCBI RefSeq accession "
             "(NM_/NR_). Default -1 (not extracted). For MANE, pass 21.")
    parser.add_option("--no-collapse", dest="noCollapse", action="store_true", default=False,
        help="Emit one row per (transcript, region) instead of collapsing identical "
             "coordinates across transcripts. Useful for sets like MANE where each "
             "label-field column should be a single value.")
    parser.add_option("--rule1-mode", dest="rule1Mode", action="store", default="cds",
        choices=["cds", "mrna"],
        help="How to count the 50 bp walk-back from the last splice junction. "
             "'cds' (default): count only CDS nucleotides, skipping 3'UTR; "
             "paints up to 50 bp of CDS regardless of how far the last junction "
             "sits past the stop codon. "
             "'mrna': count mRNA nucleotides including 3'UTR, then clip output "
             "to CDS; when the last junction is more than 50 mRNA-bp past the "
             "stop codon, no CDS position is painted (tracks the 55 bp rule "
             "literature more closely).")
    (options, args) = parser.parse_args()

    if len(args) != 3:
        parser.print_help()
        # sys.exit rather than the interactive exit() helper, which the site
        # module may not install (e.g. under `python -S`)
        sys.exit(1)

    level = logging.DEBUG if options.debug else logging.INFO
    logging.basicConfig(level=level)
    # basicConfig does nothing if the root logger already has handlers,
    # so set the level explicitly as well
    logging.getLogger().setLevel(level)

    return args, options

def _parseIntList(text):
    """Parse a comma-separated integer list such as "100,300," (trailing comma allowed)."""
    return [int(x) for x in text.strip(",").split(",") if x]


def openInput(fname, fmt, geneSymField=17, ncbiIdField=-1):
    """Open an annotation file and yield one parsed transcript dict per non-empty line.

    fmt selects the input format:
      'genePredExt' - tab-separated genePredExt with a leading bin column,
                      read directly (or via gzip when fname ends in ".gz").
      'bigGenePred' - decoded by piping through the external `bigBedToBed` tool.

    Both formats yield dicts with keys: name, chrom, strand, txStart, txEnd,
    cdsStart, cdsEnd, exonCount, exonStarts, exonEnds, geneSym, ncbiId,
    cdsStartStat, cdsEndStat, exonFrames.

    When optional trailing columns are missing:
      - geneSym, ncbiId and exonFrames default to "";
      - cdsStartStat and cdsEndStat default to "none".

    geneSymField and ncbiIdField are 0-based column indexes and are used only
    for bigGenePred. An ncbiIdField below 0 disables NCBI ID extraction.

    Raises ValueError for an unknown fmt. Raises
    subprocess.CalledProcessError if bigBedToBed exits non-zero after its
    output has been fully consumed.

    The input handle is closed, and bigBedToBed waited for, even when the
    caller stops iterating early.
    """
    # validate before opening anything, so bad formats fail even on empty input
    if fmt not in ("genePredExt", "bigGenePred"):
        raise ValueError("Unknown format: " + fmt)

    proc = None
    if fmt == "bigGenePred":
        # stderr is inherited, not piped: a pipe nobody reads can fill up and
        # stall bigBedToBed, and the user should see its error messages anyway
        proc = subprocess.Popen(["bigBedToBed", fname, "stdout"],
            stdout=subprocess.PIPE, text=True)
        fh = proc.stdout
    elif fname.endswith(".gz"):
        fh = gzip.open(fname, "rt")
    else:
        fh = open(fname)

    try:
        for line in fh:
            line = line.rstrip("\n\r")
            if not line:
                continue
            fields = line.split("\t")
            nFields = len(fields)

            if fmt == "genePredExt":
                # genePredExt with bin column:
                # bin, name, chrom, strand, txStart, txEnd, cdsStart, cdsEnd,
                # exonCount, exonStarts, exonEnds, score, name2, cdsStartStat, cdsEndStat, exonFrames
                rec = {
                    "name": fields[1],
                    "chrom": fields[2],
                    "strand": fields[3],
                    "txStart": int(fields[4]),
                    "txEnd": int(fields[5]),
                    "cdsStart": int(fields[6]),
                    "cdsEnd": int(fields[7]),
                    "exonCount": int(fields[8]),
                    "exonStarts": _parseIntList(fields[9]),
                    "exonEnds": _parseIntList(fields[10]),
                    "geneSym": fields[12] if nFields > 12 else "",
                    "ncbiId": "",
                    "cdsStartStat": fields[13] if nFields > 13 else "none",
                    "cdsEndStat": fields[14] if nFields > 14 else "none",
                    "exonFrames": fields[15].strip(",") if nFields > 15 else "",
                }
            else:
                # bigGenePred:
                # chrom, chromStart, chromEnd, name, score, strand, thickStart, thickEnd,
                # color, blockCount, blockSizes, chromStarts, name2, cdsStartStat, cdsEndStat,
                # exonFrames, type, geneName, geneName2, geneType
                chromStart = int(fields[1])
                blockSizes = _parseIntList(fields[10])
                blockStarts = _parseIntList(fields[11])  # relative to chromStart
                rec = {
                    "name": fields[3],
                    "chrom": fields[0],
                    "strand": fields[5],
                    "txStart": chromStart,
                    "txEnd": int(fields[2]),
                    "cdsStart": int(fields[6]),
                    "cdsEnd": int(fields[7]),
                    "exonCount": int(fields[9]),
                    "exonStarts": [chromStart + s for s in blockStarts],
                    "exonEnds": [chromStart + s + sz for s, sz in zip(blockStarts, blockSizes)],
                    "geneSym": fields[geneSymField] if nFields > geneSymField else "",
                    "ncbiId": fields[ncbiIdField] if (ncbiIdField >= 0 and nFields > ncbiIdField) else "",
                    "cdsStartStat": fields[13] if nFields > 13 else "none",
                    "cdsEndStat": fields[14] if nFields > 14 else "none",
                    "exonFrames": fields[15].strip(",") if nFields > 15 else "",
                }

            yield rec
    finally:
        fh.close()
        if proc is not None:
            proc.wait()

    # only reached after a full, normal iteration: surface a failed conversion
    # instead of silently producing zero transcripts
    if proc is not None and proc.returncode != 0:
        raise subprocess.CalledProcessError(proc.returncode, proc.args)

def bedOut(row, txStart, txEnd, ofh, rule):
    """Write one single-block decorator BED line for `rule` to ofh.

    row is a 4-item sequence (chrom, start, end, name); each item is
    converted with str(). The line uses the rule's color as both itemRgb and
    fill color and carries the rule's HTML mouseover text. It links back to
    the decorated item "chrom:txStart-txEnd:name".
    """
    chrom, start, end, name = map(str, row)
    color = RULE_COLORS[rule]
    blockSize = str(int(end) - int(start))
    decoratedItem = "%s:%s-%s:%s" % (chrom, txStart, txEnd, name)
    fields = (
        chrom, start, end, name,
        "0", ".",                            # score, strand
        start, end,                          # thickStart, thickEnd
        color,                               # itemRgb
        "1", blockSize, "0",                 # blockCount, blockSizes, blockStarts
        decoratedItem, "block", color, "",   # decorated item, style, fill color, glyph
        RULE_DESCRIPTIONS[rule],             # mouseover
    )
    ofh.write("\t".join(fields) + "\n")


def rule1Regions(rec, nUpstreamJxn=50, mode="cds"):
    """50 bp rule: paint CDS positions near the transcript's last splice
    junction. The last junction is taken over ALL exons of the transcript,
    including 3'UTR introns, because those introns deposit EJCs that can
    trigger NMD.

    Two modes control how the nUpstreamJxn walk-back from the last junction
    is counted:
      mode='cds'  - count only CDS nucleotides, skipping any 3'UTR as you
                    walk upstream. Paints up to nUpstreamJxn bp of CDS
                    regardless of how far the last junction sits past the
                    stop codon. This is what the 'last 50 bp of last
                    junction' label most literally means in terms of CDS.
      mode='mrna' - count mRNA nucleotides (including 3'UTR). If the
                    walk-back stays in 3'UTR the whole way, no CDS is
                    painted; otherwise the CDS overlap of the walked
                    window is painted. This tracks the 55 bp rule
                    literature, where the distance is an mRNA distance.
                    Any other mode value is treated like 'mrna'.

    Also painted: any CDS downstream of the last junction, i.e. the CDS
    overlap with the last mRNA exon.

    Strand handling: '+' is forward; any other strand value is handled as
    '-', both for exon order and for choosing the 3'-most bases.

    Returns a list of (genStart, genEnd) tuples in genomic order, CDS-only.
    Transcripts with fewer than two exons have no junction and yield [].
    """
    isPlus = rec["strand"] == "+"
    cdsStart = rec["cdsStart"]
    cdsEnd = rec["cdsEnd"]
    exonStarts = rec["exonStarts"]
    exonEnds = rec["exonEnds"]

    if len(exonStarts) < 2:
        return []  # no splice junction -> rule does not apply

    exons = list(zip(exonStarts, exonEnds))
    if not isPlus:
        exons.reverse()  # exons now in mRNA 5'->3' order

    def clipToCds(s, e):
        s2, e2 = max(s, cdsStart), min(e, cdsEnd)
        return (s2, e2) if s2 < e2 else None

    def threePrimePart(s, e, size):
        # the `size` bp of [s, e) that lie closest to the mRNA 3' end
        return (e - size, e) if isPlus else (s, s + size)

    out = []
    # everything downstream of the last junction = last mRNA exon clipped to CDS
    clipped = clipToCds(*exons[-1])
    if clipped:
        out.append(clipped)

    # walk upstream of the last junction through earlier exons in mRNA order
    remaining = nUpstreamJxn
    for exStart, exEnd in reversed(exons[:-1]):
        if remaining <= 0:
            break

        if mode == "cds":
            # count only CDS bp; 3'UTR-only exons consume nothing
            window = clipToCds(exStart, exEnd)
            if not window:
                continue
        else:
            # count every mRNA bp, clip to CDS afterwards
            window = (exStart, exEnd)

        winLen = window[1] - window[0]
        take = min(winLen, remaining)
        if take < winLen:
            window = threePrimePart(window[0], window[1], take)
        remaining -= take

        if mode != "cds":
            window = clipToCds(*window)
        if window:
            out.append(window)

    return sorted(out)


def outputExonsUpTo(from3Prime, cdsExons, chrom, txStart, txEnd, name, n, ofh, rule):
    """Write decorator BED lines to ofh covering the first n nucleotides of cdsExons.

    cdsExons is a list of (start, end) intervals in genomic order. They are
    walked from the genomic start, or from the genomic end when from3Prime
    is true. The interval that crosses the n-nucleotide mark is truncated on
    the side facing away from the walk origin. Each written interval is
    emitted via bedOut with the given rule.

    A negative n means "the whole first interval in walk order plus |n| more
    nucleotides".

    Returns the written regions as (chrom, start, end) tuples, in walk order.
    Returns [] for empty cdsExons or n == 0.
    """
    regions = []
    if not cdsExons:
        return regions

    exons = list(reversed(cdsExons)) if from3Prime else list(cdsExons)

    # e.g. -50 means "50 past the first junction": length of first exon + 50
    if n < 0:
        firstStart, firstEnd = exons[0]
        n = (firstEnd - firstStart) + (-n)

    doneNs = 0
    for start, end in exons:
        if doneNs >= n:
            break
        missBps = n - doneNs
        exLen = end - start
        truncated = exLen > missBps
        if truncated:
            # keep only the missBps nucleotides nearest the walk origin
            if from3Prime:
                start = end - missBps
            else:
                end = start + missBps
        bedOut([chrom, str(start), str(end), name], txStart, txEnd, ofh, rule)
        regions.append((chrom, start, end))
        if truncated:
            break
        doneNs += exLen
    return regions

# ----------- main --------------
def _cdsExonsOf(rec):
    """Return the CDS part of each exon of rec as (start, end) tuples in genomic order.

    Exons lying entirely in the 5' or 3' UTR are dropped. Exons straddling
    cdsStart or cdsEnd are trimmed to the CDS. Zero-length pieces, where a
    CDS boundary lands exactly on an exon boundary, are skipped.
    Assumes cdsStart < cdsEnd.
    """
    cdsStart, cdsEnd = rec["cdsStart"], rec["cdsEnd"]
    cdsExons = []
    for exStart, exEnd in zip(rec["exonStarts"], rec["exonEnds"]):
        if cdsStart > exEnd or exStart > cdsEnd:
            continue  # exon entirely in 5' UTR or 3' UTR
        start, end = max(exStart, cdsStart), min(exEnd, cdsEnd)
        if start < end:
            cdsExons.append((start, end))
    return cdsExons


def _hasThreeUtrIntron(rec):
    """Return True if some exon lies entirely in the 3' UTR.

    In that case there is a splice junction downstream of the stop codon.
    On '+' this is an exon starting after cdsEnd. On any other strand value
    (handled as '-') it is an exon ending before cdsStart.
    """
    if rec["strand"] == "+":
        return any(s > rec["cdsEnd"] for s in rec["exonStarts"])
    return any(e < rec["cdsStart"] for e in rec["exonEnds"])


def _writeCollapsed(regionData, fname):
    """Write one BED9+3 line per regionData entry, sorted by key.

    Extra columns: mouseover, comma-joined transcripts, comma-joined NCBI IDs.
    The name column is the gene symbol, or the first transcript name if no
    symbol was recorded. The strand column is "." when the contributing
    transcripts disagree.
    """
    with open(fname, "w") as collOfh:
        for key, data in sorted(regionData.items()):
            chrom, start, end, rule = key[:4]
            txList = sorted(set(data["transcripts"]))
            ncbiList = sorted(set(data["ncbiIds"]))
            geneSym = data["geneSym"] or txList[0]

            # pick strand: use the strand if all transcripts agree, else "."
            strands = data["strands"]
            strand = next(iter(strands)) if len(strands) == 1 else "."

            color = RULE_COLORS[rule]
            txListStr = ",".join(txList)
            ncbiListStr = ",".join(ncbiList)
            if len(txList) <= 3:
                mouseover = RULE_DESCRIPTIONS[rule] + " (" + ", ".join(txList) + ")"
            else:
                mouseover = RULE_DESCRIPTIONS[rule] + " (" + str(len(txList)) + " transcripts)"

            row = [chrom, str(start), str(end), geneSym, "0", strand,
                   str(start), str(end), color, mouseover, txListStr, ncbiListStr]
            collOfh.write("\t".join(row))
            collOfh.write("\n")


def main():
    """Compute NMD-escape regions for every coding transcript and write both outputs.

    Outputs:
      - decorator BED: one line per (transcript, region, rule);
      - collapsed BED9+3: one line per distinct (chrom, start, end, rule), or
        per (chrom, start, end, rule, transcript) with --no-collapse.
    """
    args, options = parseArgs()

    inFname, outDecoFname, outCollapsedFname = args

    # collect regions for the collapsed output:
    # key = (chrom, start, end, rule[, transcript]) -> {"transcripts": [...], "ncbiIds": [...],
    #                                                  "strands": set(), "geneSym": str}
    regionData = defaultdict(lambda: {"transcripts": [], "ncbiIds": [], "strands": set(), "geneSym": ""})

    def addRegions(regions, rule, name, strand, geneSym, ncbiId):
        """Record (chrom, start, end) regions produced by `rule` for one transcript."""
        for r in regions:
            key = (r[0], r[1], r[2], rule)
            if options.noCollapse:
                key += (name,)
            data = regionData[key]
            data["transcripts"].append(name)
            if ncbiId:
                data["ncbiIds"].append(ncbiId)
            data["strands"].add(strand)
            # first non-empty symbol wins
            if geneSym and not data["geneSym"]:
                data["geneSym"] = geneSym

    with open(outDecoFname, "w") as decoOfh:
        for rec in openInput(inFname, options.format, options.geneSymField, options.ncbiIdField):
            name = rec["name"]
            chrom = rec["chrom"]
            strand = rec["strand"]
            txStart = rec["txStart"]
            txEnd = rec["txEnd"]
            cdsStart = rec["cdsStart"]
            cdsEnd = rec["cdsEnd"]

            # skip non-coding transcripts (cdsStart == cdsEnd)
            if cdsStart >= cdsEnd:
                continue

            # the gene symbol may be empty; the collapsed writer then falls
            # back to the transcript name
            txInfo = (name, strand, rec["geneSym"], rec.get("ncbiId", ""))
            cdsExons = _cdsExonsOf(rec)
            hasThreeUtrIntron = _hasThreeUtrIntron(rec)

            # rule 2 applies when no EJC can be deposited downstream of the stop codon.
            # that requires: single CDS exon (no CDS-CDS junction) AND no 3'UTR intron.
            # 5'UTR introns are allowed - their EJCs are upstream of the stop and either
            # are cleared by the scanning 40S or are never encountered by the terminating
            # ribosome, so they do not trigger NMD.
            if len(cdsExons) == 1 and not hasThreeUtrIntron:
                # rule 2: no EJC downstream of stop codon -> any PTC in CDS escapes NMD
                bedOut([chrom, str(cdsStart), str(cdsEnd), name], txStart, txEnd, decoOfh, 2)
                addRegions([(chrom, cdsStart, cdsEnd)], 2, *txInfo)
                continue

            # rule 3: first 100bp of coding nucleotides (walk from the genomic
            # end for non-'+' strands, i.e. from the mRNA 5' end)
            regions = outputExonsUpTo(strand != "+", cdsExons, chrom, txStart, txEnd,
                                      name, 100, decoOfh, 3)
            addRegions(regions, 3, *txInfo)

            # rule 1: near the last splice junction of the transcript (incl.
            # 3'UTR introns), plus anything downstream of that junction,
            # clipped to CDS; see rule1Regions for the counting modes.
            rule1Tuples = []
            for s, e in rule1Regions(rec, 50, mode=options.rule1Mode):
                bedOut([chrom, str(s), str(e), name], txStart, txEnd, decoOfh, 1)
                rule1Tuples.append((chrom, s, e))
            addRegions(rule1Tuples, 1, *txInfo)

            # rule 4: coding exons >LONG_EXON_THRESHOLD nt that have a downstream
            # splice junction. Without a 3'UTR intron, the last CDS exon is the
            # last mRNA exon and is excluded (no downstream EJC). With a 3'UTR
            # intron, every CDS exon has a downstream junction and is eligible.
            if hasThreeUtrIntron:
                rule4Exons = cdsExons
            elif strand == "+":
                rule4Exons = cdsExons[:-1]
            else:
                rule4Exons = cdsExons[1:]

            for exStart, exEnd in rule4Exons:
                if exEnd - exStart > LONG_EXON_THRESHOLD:
                    bedOut([chrom, str(exStart), str(exEnd), name], txStart, txEnd, decoOfh, 4)
                    addRegions([(chrom, exStart, exEnd)], 4, *txInfo)

    # write collapsed output as bed 9 + mouseover + transcripts + ncbiIds
    _writeCollapsed(regionData, outCollapsedFname)

    logging.info("Wrote %d decorator regions to %s" % (sum(len(d["transcripts"]) for d in regionData.values()), outDecoFname))
    logging.info("Wrote %d collapsed regions to %s" % (len(regionData), outCollapsedFname))

# run only when executed as a script, so the module can be imported (e.g. for tests)
if __name__ == "__main__":
    main()
