#!/usr/bin/env python3

import sys
import os
myBinDir = os.path.normpath(os.path.dirname(sys.argv[0]))
sys.path.append(os.path.join(myBinDir, "../lib"))
sys.path.append(os.path.expanduser("/hive/groups/browser/pycbio/lib"))
import argparse
from pycbio.sys.symEnum import SymEnum
from gencode.gencodeGenes import GencodeGenes, GencodeFunction, GencodeTag

# Subsets of GENCODE that can be requested on the command line; see
# useTranscript() for the rule each subset applies to a transcript.
GencodeSubset = SymEnum("GencodeSubset", ("pseudogene", "basic", "comp"))


def parseArgs():
    """Parse the command line.

    Returns the argparse namespace with attributes gencodeSubset (converted
    with GencodeSubset), gencodeGp, gencodeAttrs and subsetGp.
    """
    # No hand-written usage line in the description: argparse builds the usage
    # from the arguments below and does not expand optparse's "%prog"
    # placeholder, so such a line would be printed literally (and go stale).
    desc = """Extracts genePred records for a track based on various criteria and sort
    results.
"""
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument("gencodeSubset", type=GencodeSubset, choices=GencodeSubset,
                        help="Subset of GENCODE to output")
    parser.add_argument("gencodeGp",
                        help="full GENCODE genePred")
    parser.add_argument("gencodeAttrs",
                        help="gencode attributes TSV")
    parser.add_argument("subsetGp",
                        help="track genePred")
    return parser.parse_args()

def useTranscript(gencodeSubset, transcript):
    """Decide whether a transcript belongs in the requested subset.

    pseudogene: transcripts whose function is GencodeFunction.pseudo.
    basic: transcripts tagged GencodeTag.basic whose function is not pseudo.
    comp: all transcripts whose function is not pseudo.

    Raises AssertionError if gencodeSubset is not one of the handled values.
    """
    if gencodeSubset == GencodeSubset.pseudogene:
        return transcript.getFunction() == GencodeFunction.pseudo
    if gencodeSubset == GencodeSubset.basic:
        return (GencodeTag.basic in transcript.tags) and (transcript.getFunction() != GencodeFunction.pseudo)
    if gencodeSubset == GencodeSubset.comp:
        return transcript.getFunction() != GencodeFunction.pseudo
    # Raised explicitly instead of `assert False`: assert statements are
    # removed under `python -O`, which would turn an unhandled subset into a
    # silent None return and an empty output file.
    raise AssertionError(f"gencodeSubset '{gencodeSubset}' not handled")

def filterGeneLocus(gencodeSubset, geneLocus, genePreds):
    """Append to genePreds the genePred (.gp) of every transcript locus in
    geneLocus whose transcript is accepted by useTranscript() for the subset.

    genePreds is modified in place; nothing is returned.
    """
    genePreds.extend(txLocus.gp
                     for txLocus in geneLocus.transcriptLoci
                     if useTranscript(gencodeSubset, txLocus.transcript))

def filterGenes(gencode, gencodeSubset):
    """Collect the genePreds selected for gencodeSubset from all gene loci of
    all genes in gencode.genesById.

    Returns a new list, in the order the genes and their loci are iterated.
    """
    selected = []
    allLoci = (locus
               for gene in gencode.genesById.values()
               for locus in gene.geneLoci)
    for locus in allLoci:
        filterGeneLocus(gencodeSubset, locus, selected)
    return selected

def writeSubset(genePreds, gpFh):
    """Sort genePreds in place by (chrom, txStart, txEnd, name) and write each
    record to gpFh in that order using the record's own write() method.
    """
    def locationKey(rec):
        return (rec.chrom, rec.txStart, rec.txEnd, rec.name)

    # in-place sort: the caller's list is left in sorted order
    genePreds.sort(key=locationKey)
    for rec in genePreds:
        rec.write(gpFh)


def gencodeMakeTracks(opts):
    """Run the extraction: load GENCODE from opts.gencodeGp and
    opts.gencodeAttrs, select the genePreds for opts.gencodeSubset, and write
    them sorted to opts.subsetGp.
    """
    annotations = GencodeGenes.loadFromFiles(opts.gencodeGp, opts.gencodeAttrs)
    selected = filterGenes(annotations, opts.gencodeSubset)
    # output is opened only after loading and filtering have succeeded
    with open(opts.subsetGp, "w") as outFh:
        writeSubset(selected, outFh)


# Run only when executed as a script, so the functions above can be imported
# (e.g. for testing) without parsing sys.argv or touching any files.
if __name__ == "__main__":
    gencodeMakeTracks(parseArgs())
