#!/usr/bin/env python3

import sys
import re
import os.path as osp
import argparse
import subprocess
from collections import defaultdict, namedtuple

sys.path.append('/hive/groups/browser/pycbio/lib')
from pycbio.sys.objDict import ObjDict
from pycbio.tsv import TsvReader

# These flags are for analysis; they select which location sources are loaded
# and consulted.  HgncIdCoordsMap.get tries the enabled sources in the order
# GENCODE, RefSeq curated, RefSeq other and uses the first one with the id.
INCL_GENCODE = True
INCL_REFSEQ_CURATED = True
INCL_REFSEQ_OTHER = True

def parseArgs(argv=None):
    """Parse and validate the command line.

    argv is an optional list of argument strings (without the program name);
    when None, argparse reads sys.argv[1:].

    Returns the argparse namespace.  Exits through parser.error() if
    gencodeVersion does not look like `V47' or `V47lift37'.
    """
    desc = """Generate bigBed and trix indexes for HGNC gene identifiers.
    This uses the GENCODE SQL tables, and RefSeq curated and other data to
    find locations of the HGNC entries.  RefSeq predicted is not used, as
    these would not be used to define a HGNC locus, only add new
    transcripts.  """

    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument("--unmappedTsv",
                        help="Write HGNC entries that couldn't be mapped to this file.  "
                        "Some are expected, including genes not in genome and types without coordinates")
    parser.add_argument("ucscDbName", choices=['hg38', 'hg19'],
                        help="UCSC database name")
    parser.add_argument("gencodeVersion",
                        help="GENCODE version, e.g. V47 on hg38 and V47lift37 on hg19")
    parser.add_argument("hgncFile",
                        help="Path to the HGNC complete set file from: "
                        "https://storage.googleapis.com/public-download-files/hgnc/tsv/tsv/hgnc_complete_set.txt")
    parser.add_argument("hgncBed",
                        help="Output BED file name")
    parser.add_argument("hgncBigBed",
                        help="Output bigBed file name")
    parser.add_argument("trixPrefix",
                        help="prefix file name for Trix input")
    parser.add_argument("locusTypeFilter",
                        help="text for locus_type filter, with RNAs first")
    args = parser.parse_args(argv)
    # the version is substituted into SQL table names, so insist on the exact form
    if not re.fullmatch(r'V\d+(lift37)?', args.gencodeVersion):
        parser.error(f"invalid GENCODE version '{args.gencodeVersion}', "
                     "expected string like `V47' or `V47lift37'")
    return args

# These must match hgncBig62.as.  The HGNC TSV 'name' column is renamed to
# 'geneName' ('name' is already a standard BED field, holding the HGNC id).
stdBedFields = ("chrom", "chromStart", "chromEnd", "name", "score", "strand",
                "thickStart", "thickEnd", "itemRgb")

# extra columns following the standard BED fields, in output order
hgncBedFields = ("symbol", "geneName", "locus_group", "locus_type", "status",
                 "location", "location_sortable", "alias_symbol", "alias_name",
                 "prev_symbol", "prev_name", "gene_group", "gene_group_id",
                 "date_approved_reserved", "date_symbol_changed", "date_name_changed",
                 "date_modified", "entrez_id", "ensembl_gene_id", "vega_id", "ucsc_id",
                 "ena", "refseq_accession", "ccds_id", "uniprot_ids", "pubmed_id", "mgd_id",
                 "rgd_id", "lsdb", "cosmic", "omim_id", "mirbase", "homeodb", "snornabase",
                 "bioparadigms_slc", "orphanet", "pseudogene.org", "horde_id", "merops",
                 "imgt", "iuphar", "kznf_gene_catalog", "mamit-trnadb", "cd", "lncrnadb",
                 "enzyme_id", "intermediate_filament_db", "rna_central_id", "lncipedia",
                 "gtrnadb", "agr", "mane_select", "gencc")
# full column order of a written BED row
allFields = stdBedFields + hgncBedFields

class HGNCBed(ObjDict):
    "in-memory hgncBig62.as record, both a dict and an object with fields"

    def __init__(self, chrom, chromStart, chromEnd, name, strand, itemRgb):
        # standard BED columns, assigned in BED column order; score is always 0
        self.chrom, self.chromStart, self.chromEnd = chrom, chromStart, chromEnd
        self.name, self.score, self.strand = name, 0, strand
        # the thick region covers the whole item
        self.thickStart, self.thickEnd = chromStart, chromEnd
        self.itemRgb = itemRgb

    def __str__(self):
        "tab-separated row with the columns in allFields order"
        return "\t".join(str(self[fld]) for fld in allFields)

_HgncCoordsBase = namedtuple("Coords", "hgncId chrom start end strand")


class HgncCoords(_HgncCoordsBase):
    "gene-level location of one HGNC id on one chromosome"

# itemRgb values ("r,g,b" strings) used for the locus type classes
RNA_COLOR = "9,99,11"
PROTEIN_CODING_COLOR = "13,20,118"
PSEUDOGENE_COLOR = "253,66,253"
OTHER_COLOR = "0,0,0"

# Color coding keyed by locus_type after comma removal (see renameLocusType),
# e.g. 'RNA, micro' is looked up as 'micro RNA'.  getLocusColor raises an
# exception for a locus_type that is missing from this table.
LOCUS_TYPE_COLOR_MAP = {
    "Y RNA": RNA_COLOR,
    "cluster RNA": RNA_COLOR,
    "long non-coding RNA": RNA_COLOR,
    "micro RNA": RNA_COLOR,
    "misc RNA": RNA_COLOR,
    "ribosomal RNA": RNA_COLOR,
    "small nuclear RNA": RNA_COLOR,
    "small nucleolar RNA": RNA_COLOR,
    "transfer RNA": RNA_COLOR,
    "vault RNA": RNA_COLOR,
    "T cell receptor gene": PROTEIN_CODING_COLOR,
    "T cell receptor pseudogene": PSEUDOGENE_COLOR,
    "complex locus constituent": OTHER_COLOR,
    "endogenous retrovirus": OTHER_COLOR,
    "fragile site": OTHER_COLOR,
    "gene with protein product": PROTEIN_CODING_COLOR,
    "immunoglobulin gene": OTHER_COLOR,
    "immunoglobulin pseudogene": PSEUDOGENE_COLOR,
    "protocadherin": OTHER_COLOR,
    "pseudogene": PSEUDOGENE_COLOR,
    "readthrough": OTHER_COLOR,
    "region": OTHER_COLOR,
    "unknown": OTHER_COLOR,
    "virus integration site": OTHER_COLOR
}

####
# reading HGNC id to coordinates from various sources
####

def cmdTsvReader(cmd, columnNameMapper=None):
    """Run a command and yield the rows of the TSV it writes to stdout.

    cmd is the argument list passed to subprocess; cmd[0] is given to
    TsvReader as the file name.  The txStart and txEnd columns are converted
    to int.  columnNameMapper is passed through to TsvReader.

    Raises subprocess.CalledProcessError if the command exits non-zero; this
    is only checked once all rows have been consumed.
    """
    coordTypes = dict(txStart=int, txEnd=int)
    with subprocess.Popen(cmd, stdout=subprocess.PIPE, text=True) as child:
        yield from TsvReader(cmd[0], inFh=child.stdout, typeMap=coordTypes,
                             columnNameMapper=columnNameMapper)
        exitCode = child.wait()
        if exitCode != 0:
            raise subprocess.CalledProcessError(exitCode, cmd)


def readHgncTranscriptCoords(tsvReader):
    """Group transcript rows by HGNC id and then by chromosome.

    tsvReader is any iterable of rows, which can be read from a database or
    a bigBed; each row needs the columns hgncId, chrom, txStart, txEnd and
    strand.  Returns a two dimensional dict in the form [hgncId][chrom],
    holding the list of rows.  Looking up an unknown hgncId in the result
    raises KeyError.
    """
    byHgncId = defaultdict(lambda: defaultdict(list))
    for transRow in tsvReader:
        chromRows = byHgncId[transRow.hgncId]
        chromRows[transRow.chrom].append(transRow)
    # turn off auto-creation so a miss on the outer dict is an error
    byHgncId.default_factory = None
    return byHgncId

def transCoordsToGeneCoords(transChromRows):
    """Convert the transcripts of one HGNC id on one chromosome to gene
    coordinates.

    The result spans the smallest txStart to the largest txEnd; hgncId, chrom
    and strand are taken from the first row.  transChromRows must not be empty.
    """
    firstRow = transChromRows[0]
    geneStart = min(row.txStart for row in transChromRows)
    geneEnd = max(row.txEnd for row in transChromRows)
    return HgncCoords(firstRow.hgncId, firstRow.chrom, geneStart, geneEnd,
                      firstRow.strand)

def buildHgncCoordsMap(tsvReader):
    """Given a reader of HGNC id and transcript coordinate rows, build a dict
    of HGNC id to a list of gene coordinates, one entry per chromosome.
    Alts and patches mean a HGNC id can be on multiple chromosomes."""
    coordsByHgncId = defaultdict(list)
    transCoords = readHgncTranscriptCoords(tsvReader)
    for hgncId, rowsByChrom in transCoords.items():
        for chromRows in rowsByChrom.values():
            geneCoords = transCoordsToGeneCoords(chromRows)
            coordsByHgncId[hgncId].append(geneCoords)
    return coordsByHgncId

def sqlHgncTransTsvReader(ucscDbName, sql):
    "generate rows from the output of an hgsql query on the given database"
    hgsqlCmd = ["hgsql", ucscDbName, "-e", sql]
    for row in cmdTsvReader(hgsqlCmd):
        yield row

def gencodeHgncTransTsvReader(ucscDbName, gencodeVersion):
    """Read HGNC id and transcript coordinates from the GENCODE SQL tables.

    gencodeVersion is the table name suffix (e.g. V47).  The Comp and
    PseudoGene tables are each joined with the GeneSymbol table on the
    transcript id and the two results combined.
    """
    query = f"""
    SELECT geneId as hgncId, chrom, txStart, txEnd, strand
       FROM wgEncodeGencodeComp{gencodeVersion} comp,
            wgEncodeGencodeGeneSymbol{gencodeVersion} as sym
       WHERE (comp.name = sym.transcriptId)
    UNION
    SELECT geneId as hgncId, chrom, txStart, txEnd, strand
       FROM wgEncodeGencodePseudoGene{gencodeVersion} comp,
            wgEncodeGencodeGeneSymbol{gencodeVersion} as sym
       WHERE (comp.name = sym.transcriptId)
    """
    return sqlHgncTransTsvReader(ucscDbName, query)

def refSeqCuratedHgncTransTsvReader(ucscDbName):
    """Read HGNC id and transcript coordinates from the RefSeq curated tables.

    Only transcripts with a non-empty hgnc value in ncbiRefSeqLink are
    returned; "HGNC:" is prepended to that value to form the id.
    """
    query = """
    SELECT concat("HGNC:", hgnc) as hgncId, chrom, txStart, txEnd, strand
    FROM  ncbiRefSeqCurated rsc, ncbiRefSeqLink rsl
    WHERE (rsc.name = rsl.id) AND (rsl.hgnc != "")
    """
    return sqlHgncTransTsvReader(ucscDbName, query)

def refSeqOtherNameMapper(colName):
    "map bigBed column names to genePred-style names, others are unchanged"
    renamed = {
        "chromStart": "txStart",
        "chromEnd": "txEnd",
        "HGNC": "hgncId",
    }
    return renamed.get(colName, colName)

def refSeqOtherHgncTransTsvReader(ucscDbName):
    """Generate rows from the RefSeq other bigBed, converted to TSV with
    bigBedToBed and with columns renamed by refSeqOtherNameMapper"""
    bigBedPath = f"/gbdb/{ucscDbName}/ncbiRefSeq/ncbiRefSeqOther.bb"
    toTsvCmd = ["bigBedToBed", "-tsv", bigBedPath, "stdout"]
    for row in cmdTsvReader(toTsvCmd, refSeqOtherNameMapper):
        yield row


class HgncIdCoordsMap:
    """HGNC coordinates read from the enabled sources (GENCODE, RefSeq
    curated, RefSeq other)"""
    def __init__(self, ucscDbName, gencodeVersion):
        # a map stays None when its source is disabled by the INCL_* flag
        self.gencodeMap = None
        self.refseqCuratedMap = None
        self.refseqOtherMap = None
        if INCL_GENCODE:
            gencodeReader = gencodeHgncTransTsvReader(ucscDbName, gencodeVersion)
            self.gencodeMap = buildHgncCoordsMap(gencodeReader)
        if INCL_REFSEQ_CURATED:
            curatedReader = refSeqCuratedHgncTransTsvReader(ucscDbName)
            self.refseqCuratedMap = buildHgncCoordsMap(curatedReader)
        if INCL_REFSEQ_OTHER:
            otherReader = refSeqOtherHgncTransTsvReader(ucscDbName)
            self.refseqOtherMap = buildHgncCoordsMap(otherReader)

    def get(self, hgncId, default=None):
        """Get the list of HgncCoords objects for a HGNC ID, or default if
        no enabled source has it.  The first source with the id wins."""
        # priority order: GENCODE, then RefSeq curated, then RefSeq other
        sources = ((INCL_GENCODE, self.gencodeMap),
                   (INCL_REFSEQ_CURATED, self.refseqCuratedMap),
                   (INCL_REFSEQ_OTHER, self.refseqOtherMap))
        for enabled, coordsMap in sources:
            if not enabled:
                continue
            found = coordsMap.get(hgncId)
            if found is not None:
                return found
        return default

###
# processing of HGNC data
###
def loadHgncGenes(hgncFile):
    "read all rows of the HGNC TSV file into a list"
    hgncGenes = list(TsvReader(hgncFile))
    return hgncGenes

def hgncGenePrevSymbols(prevSymbols):
    """Split a prev_symbol value into a list of symbols.

    The value is a '|' separated list that may be wrapped in double quotes,
    e.g. '"MSK16|CSPG1|AGC1"'.  An empty value returns an empty list rather
    than a list holding one empty string.
    """
    return [sym for sym in prevSymbols.strip('"').split('|') if sym]

def getLocusColor(locusType):
    """Look up the itemRgb color for a (comma-free) locus type; raises
    Exception if the type is not in LOCUS_TYPE_COLOR_MAP"""
    if locusType not in LOCUS_TYPE_COLOR_MAP:
        raise Exception(f"No color for locus_type `{locusType}', most likely HGNC added a new type")
    return LOCUS_TYPE_COLOR_MAP[locusType]

def renameLocusType(locusType):
    """Move a leading 'RNA, ' to the end of the locus type, since bigBed
    filters can't include commas due to trackDb syntax:
    'RNA, micro' -> 'micro RNA'.  Other types are returned unchanged."""
    if not locusType.startswith("RNA, "):
        return locusType
    rnaKind = locusType.split(", ")[1]
    return rnaKind + " RNA"

def hgncGeneToBed(hgncGene, coords):
    """Build the HGNCBed for one HGNC row at one location.

    The BED name is the HGNC id and the color is chosen from the renamed
    locus type.  The remaining columns are copied from the HGNC row.
    """
    locusType = renameLocusType(hgncGene.locus_type)
    bed = HGNCBed(chrom=coords.chrom, chromStart=coords.start, chromEnd=coords.end,
                  name=hgncGene.hgnc_id, strand=coords.strand,
                  itemRgb=getLocusColor(locusType))

    for fld in hgncBedFields:
        if fld == "geneName":
            # the HGNC 'name' column is stored under 'geneName'
            value = hgncGene.name
        elif fld == "locus_type":
            # use the comma-free form
            value = locusType
        else:
            # all other columns are copied with surrounding quotes removed
            value = getattr(hgncGene, fld).strip('"')
        bed[fld] = value
    return bed

def hgncGeneToBeds(hgncGene, hgncIdCoordsMap, hgncMappedIds):
    """Get the list of BEDs for a HGNC row, one per location; empty if the
    HGNC id has no coordinates.  Ids with coordinates are added to the
    hgncMappedIds set."""
    hgncId = hgncGene.hgnc_id
    coordsList = hgncIdCoordsMap.get(hgncId)
    if coordsList is None:
        return []
    hgncMappedIds.add(hgncId)
    return [hgncGeneToBed(hgncGene, coords) for coords in coordsList]

def generateHgncBeds(hgncGenes, hgncIdCoordsMap, hgncMappedIds):
    """Build the BEDs for all HGNC rows, in input order; rows without
    coordinates contribute nothing.  Mapped ids are added to hgncMappedIds."""
    return [bed
            for hgncGene in hgncGenes
            for bed in hgncGeneToBeds(hgncGene, hgncIdCoordsMap, hgncMappedIds)]

def writeHgncBeds(hgncBeds, bedFh):
    "write BEDs one per line, sorted by chrom, chromStart and chromEnd"
    def location(bed):
        return (bed.chrom, bed.chromStart, bed.chromEnd)

    for bed in sorted(hgncBeds, key=location):
        print(bed, file=bedFh)

def buildBigBed(ucscDbName, hgncBedFile, hgncBigBed):
    """Run bedToBigBed to convert the BED file to a bigBed with an extra
    index on name.  Raises subprocess.CalledProcessError on failure."""
    # the .as file is located relative to this file
    scriptDir = osp.dirname(__file__)
    asFile = osp.join(scriptDir, "../../../lib", "hgncBig62.as")
    chromSizes = f"/cluster/data/{ucscDbName}/chrom.sizes"
    subprocess.run(["bedToBigBed", "-extraIndex=name", "-tab", "-type=bed9+53",
                    f"-as={asFile}", hgncBedFile, chromSizes, hgncBigBed],
                   check=True)

def trixValues(hgncBed):
    """Build the trix text for a BED: the symbol followed by the previous
    symbols, space-separated and terminated by ';'"""
    # prev_symbol looks like MSK16|CSPG1|AGC1
    prevText = " ".join(hgncGenePrevSymbols(hgncBed.prev_symbol))
    return "".join((hgncBed.symbol, " ", prevText, ";"))

def buildTrixInput(hgncBeds, trixInFh):
    """Write one tab-separated line of name and trix text per HGNC id.
    Alts & patches can result in multiple BEDs with the same name; only the
    first is written."""
    seen = set()
    for bed in hgncBeds:
        if bed.name in seen:
            continue
        print(bed.name, trixValues(bed), sep='\t', file=trixInFh)
        seen.add(bed.name)

def buildTrix(hgncBeds, trixPrefix):
    """Write the trix input to <trixPrefix>.trixin and run ixIxx on it to
    create <trixPrefix>.ix and <trixPrefix>.ixx.  Raises
    subprocess.CalledProcessError if ixIxx fails."""
    trixInPath = f"{trixPrefix}.trixin"
    with open(trixInPath, 'w') as trixInFh:
        buildTrixInput(hgncBeds, trixInFh)
    subprocess.run(["ixIxx", "-maxWordLength=32", trixInPath,
                    f"{trixPrefix}.ix", f"{trixPrefix}.ixx"],
                   check=True)

def writeLocusTypeFilter(hgncBeds, locusTypeFilter):
    """Write the comma-separated locus_type filter values to a file.

    hgncBeds supplies the locus types in use; locusTypeFilter is the path of
    the file to write.  Types containing "RNA" are listed first, then all
    others, each group sorted.  A single line is written.
    """
    rnaTypes = set()
    otherTypes = set()
    # split by RNA, and others
    for hgncBed in hgncBeds:
        if "RNA" in hgncBed.locus_type:
            rnaTypes.add(hgncBed.locus_type)
        else:
            otherTypes.add(hgncBed.locus_type)

    # join the combined list so that an empty group does not leave a stray
    # leading or trailing comma
    filterVal = ",".join(sorted(rnaTypes) + sorted(otherTypes))
    with open(locusTypeFilter, 'w') as fh:
        print(filterVal, file=fh)

def writeUnmappedTsv(hgncGenes, hgncMappedIds, tsvFh):
    """Write the header followed by every HGNC row whose id is not in
    hgncMappedIds.  hgncGenes must not be empty, as the header is written
    using its first row."""
    headerRow = hgncGenes[0]
    headerRow.writeHeader(tsvFh)
    unmapped = (gene for gene in hgncGenes if gene.hgnc_id not in hgncMappedIds)
    for gene in unmapped:
        gene.write(tsvFh)

def hgncToBigBed(ucscDbName, gencodeVersion, hgncFile, hgncBedFile, hgncBigBed,
                 trixPrefix, locusTypeFilter, unmappedTsv):
    """Load the coordinates and the HGNC rows, then write the BED file, the
    bigBed, the trix index and the locus type filter.  If unmappedTsv is not
    None, the HGNC rows without coordinates are written to that file."""
    coordsMap = HgncIdCoordsMap(ucscDbName, gencodeVersion)
    genes = loadHgncGenes(hgncFile)
    mappedIds = set()
    beds = generateHgncBeds(genes, coordsMap, mappedIds)

    # the bigBed is built from the BED file, so that must be complete first
    with open(hgncBedFile, 'w') as bedFh:
        writeHgncBeds(beds, bedFh)
    buildBigBed(ucscDbName, hgncBedFile, hgncBigBed)
    buildTrix(beds, trixPrefix)
    writeLocusTypeFilter(beds, locusTypeFilter)

    if unmappedTsv is None:
        return
    with open(unmappedTsv, 'w') as tsvFh:
        writeUnmappedTsv(genes, mappedIds, tsvFh)

def main():
    "command line entry point: parse the arguments and run the conversion"
    opts = parseArgs()
    hgncToBigBed(ucscDbName=opts.ucscDbName,
                 gencodeVersion=opts.gencodeVersion,
                 hgncFile=opts.hgncFile,
                 hgncBedFile=opts.hgncBed,
                 hgncBigBed=opts.hgncBigBed,
                 trixPrefix=opts.trixPrefix,
                 locusTypeFilter=opts.locusTypeFilter,
                 unmappedTsv=opts.unmappedTsv)

if __name__ == "__main__":
    main()
