#!/usr/bin/env python3

import sys
import os
myBinDir = os.path.normpath(os.path.dirname(sys.argv[0]))
sys.path.append(os.path.join(myBinDir, "../lib"))
sys.path.append(os.path.expanduser("/hive/groups/browser/pycbio/lib"))
from collections import defaultdict
import argparse
from pycbio.sys import fileOps
from pycbio.tsv import TsvReader, TabFileReader

# Note:
#  hg19 backmap generates join errors because this transcript
#  is missing:
#   ENSG00000168939.6 / ENST00000302805.2 / SPRY3
#
# This is the case of SPRY3, which now has a transcript past the PAR.
# While addressed in hg38, the state of GENCODE V19 causes issues,
# so we added these to the metadata.
#

def parseArgs():
    """Build the command-line parser and parse sys.argv.

    Returns the argparse namespace with attributes verbose, drop (a list,
    one entry per --drop occurrence), targetDb, srcAttrsTsv, targetAttrsTsv,
    inTab and outTab (defaults to /dev/stdout when omitted).
    """
    parser = argparse.ArgumentParser(
        description="""Edit GENCODE metadata tab file use mapped ids.  When there are multiple mappings,lines are duplicated.""")

    # optional flags
    parser.add_argument("--verbose", action="store_true", default=False,
                        help="verbose output for debugging")
    parser.add_argument("--drop", action="append", default=[],
                        help="drop rows with this id, maybe repeated")

    # required positionals, in command-line order
    requiredPositionals = (
        ("targetDb", "target UCSC database"),
        ("srcAttrsTsv", "attributes from src (newer) database"),
        ("targetAttrsTsv", "attributes from target (older) database"),
        ("inTab", "tab file to edit"),
    )
    for argName, argHelp in requiredPositionals:
        parser.add_argument(argName, help=argHelp)

    # optional trailing positional
    parser.add_argument("outTab", nargs='?', default="/dev/stdout",
                        help="output tab file")
    return parser.parse_args()

class GencodeBackMapIdMap(defaultdict):
    """Map GENCODE standard transcript and gene ids to the possible back-mapped ids.  Used to
    update metadata tables.

    Keys are source ids, obtained by truncating a mapped id at its first '_'.
    Values are sets of the full mapped ids that share that source id.  A
    mapped id that contains no '_' maps to itself.
    """
    def __init__(self, mappedAttrsTsv):
        """Load the map from the TSV file mappedAttrsTsv, whose records must
        provide geneId and transcriptId fields."""
        super().__init__(set)
        for rec in TsvReader(mappedAttrsTsv):
            self.__addIds(rec)

    @staticmethod
    def __getSrcId(mappedId):
        """Return the portion of mappedId before the first '_'."""
        return mappedId.split('_', 1)[0]

    def __addId(self, mappedId):
        """Record mappedId under its source id."""
        self[self.__getSrcId(mappedId)].add(mappedId)

    def __addIds(self, row):
        """Record both the gene and the transcript id of an attributes row."""
        self.__addId(row.geneId)
        self.__addId(row.transcriptId)

    def dump(self, fh):
        """Write one 'srcId = mappedId ...' line per source id to fh.

        Output goes to the supplied file handle; mapped ids are sorted so the
        listing is reproducible between runs.
        """
        for srcId, mappedIds in self.items():
            fh.write("{} = {}\n".format(srcId, " ".join(sorted(mappedIds))))

def mapTabFileRow(srcIdMap, targetIdMap, row, dupCheck, outTabFh):
    """Write row to outTabFh once for each back-mapped id of its first column.

    The id in row[0] is looked up in srcIdMap first and, only when absent
    there, in targetIdMap.  Each output row is the mapped id followed by the
    remaining columns of row.  Rows already present in dupCheck are skipped;
    rows that are written are added to dupCheck.

    Raises Exception if the id is in neither map.
    """
    key = row[0]
    if key in srcIdMap:
        mappedIds = srcIdMap[key]
    elif key in targetIdMap:
        mappedIds = targetIdMap[key]
    else:
        raise Exception("id {} not found in attributes table".format(key))

    for mappedId in mappedIds:
        editedRow = tuple([mappedId] + row[1:])
        if editedRow in dupCheck:
            continue  # identical row was already written
        fileOps.prRow(outTabFh, editedRow)
        dupCheck.add(editedRow)

def isBackMapGeneSource(targetDb, inTab):
    """Return True when targetDb is "hg19" and the inTab name contains
    ".metadata.Gene_source." (e.g. gencode.v39lift37.metadata.Gene_source.gz)."""
    if targetDb != "hg19":
        return False
    return inTab.find(".metadata.Gene_source.") != -1

def isBackMapTranscriptSource(targetDb, inTab):
    """Return True when targetDb is "hg19" and the inTab name contains
    ".metadata.Transcript_source."
    (e.g. gencode.v39lift37.metadata.Transcript_source.gz)."""
    if targetDb != "hg19":
        return False
    return inTab.find(".metadata.Transcript_source.") != -1

def mapTabFileIds(targetDb, srcIdMap, targetIdMap, inTab, dropIds, outTabFh):
    """Read inTab and write every row whose first column is not in dropIds to
    outTabFh with its id replaced by the mapped ids.

    targetDb is accepted but not used by this function.
    """
    # At least one release had some metadata entries accidentally duplicated,
    # so whole rows that were already written are tracked and not repeated.
    writtenRows = set()
    keptRows = (tabRow for tabRow in TabFileReader(inTab)
                if tabRow[0] not in dropIds)
    for tabRow in keptRows:
        mapTabFileRow(srcIdMap, targetIdMap, tabRow, writtenRows, outTabFh)


def gencodeBackMapMetadataIds(opts):
    """Run the id mapping described by the parsed command-line options.

    Loads the source and target id maps from opts.srcAttrsTsv and
    opts.targetAttrsTsv, dumps both to stderr when opts.verbose is set, then
    writes the edited copy of opts.inTab to opts.outTab.
    """
    dropIds = frozenset(opts.drop)
    srcIdMap = GencodeBackMapIdMap(opts.srcAttrsTsv)
    targetIdMap = GencodeBackMapIdMap(opts.targetAttrsTsv)

    if opts.verbose:
        labelledMaps = ((">>> srcAttrs\n", srcIdMap),
                        (">>> targetAttrs\n", targetIdMap))
        for banner, idMap in labelledMaps:
            sys.stderr.write(banner)
            idMap.dump(sys.stderr)

    with open(opts.outTab, "w") as outTabFh:
        mapTabFileIds(opts.targetDb, srcIdMap, targetIdMap, opts.inTab,
                      dropIds, outTabFh)


# Script entry point: parse the command line and run the id mapping.  This is
# executed at module level, so it also runs if the file is imported.
gencodeBackMapMetadataIds(parseArgs())
