#!/usr/bin/env python3
import sys
import os.path as osp
import re
import argparse
import glob
from collections import namedtuple, defaultdict

###
# DO NOT CODE REVIEW THIS.  IT WILL NEVER BE USED AGAIN.  SPEND OUR FUNDING ON
# SOMETHING IMPORTANT OR JUST REVIEW THE TRACKS.
###

sys.path.insert(0, '/hive/groups/browser/pycbio/lib')
from pycbio.sys import cmdOps, fileOps
from pycbio.tsv import TsvReader
from pycbio.hgdata.bed import Bed, BedReader
from pycbio.hgdata.autoSql import strArrayJoin

def parseArgs():
    """Define the command line for the three subcommands and parse it.

    The subcommands are ``make-beds``, ``make-trackdb`` and ``check-bams``;
    the one selected is stored in the ``command`` attribute.  Parsing is
    delegated to ``cmdOps.parse`` and its result is returned unchanged
    (``main()`` unpacks it as a two-element tuple).
    """
    usage = """convert CLS BED and attributes to a BED with attributes.

    Input files create with:

    zcat Hv3_masterTable_refined.gtf.gz | gffread -F | gff3ToGenePred -attrsOut=Hv3_masterTable_refined.attrs.tab stdin stdout |\\
       genePredToBed stdin stdout | tawk '{$7=$2; print}' | sort -k1,1 -k2,2n |gzip -9 > Hv3_masterTable_refined.bed.gz

    CLS metadata documented at: https://github.com/guigolab/gencode-cls-master-table
    other info: https://github.com/guigolab/CLS3_GENCODE
    """
    topParser = argparse.ArgumentParser(description=usage)
    commands = topParser.add_subparsers(dest="command", required=True)

    # make-beds: positional arguments in command-line order
    bedsCmd = commands.add_parser("make-beds")
    for positional in ("clsMetadataTab", "masterBedFile", "transcriptMetaTsv", "trackDir"):
        bedsCmd.add_argument(positional)

    # make-trackdb: optional parent track followed by the positionals
    trackDbCmd = commands.add_parser("make-trackdb")
    trackDbCmd.add_argument("--parent")
    for positional in ("clsMetadataTab", "trackDir", "trackDb"):
        trackDbCmd.add_argument(positional)

    # check-bams: positional arguments in command-line order
    checkCmd = commands.add_parser("check-bams")
    for positional in ("clsMetadataTab", "trackDir"):
        checkCmd.add_argument(positional)

    return cmdOps.parse(topParser)

##
# parsing of sample ids: SIDHBrAOC0301
#
#  https://github.com/guigolab/gencode-cls-master-table?tab=readme-ov-file
#
#    0 Fixed prefix SID
#    1 Single letter code for the organism H (human), M (mouse)
#    2 Two letter code indicating the tissue See Tissue codes
#    3 Single letter code for the stage A (adult), E (embryo), P (placenta)
#    4 Single letter code for the sequencing technology O (ont), P (PacBio)
#    5 Single letter code for capture status P (pre-capture), C (post-capture)
#    6 Two digit code for the biological replicate -
#    7 Two digit code for the technical replicate -
###
class ClsMeta(namedtuple("ClsMeta",
                         "sid "
                         "org_code tissue_code stage_code platform_code capture_code "
                         "tissue stage platform capture bioRep techRep")):
    """Sample metadata record.

    Holds the sample id, the single/double letter codes parsed out of it
    (``*_code`` fields, ``bioRep``, ``techRep``) and the descriptive
    ``tissue``, ``stage``, ``platform`` and ``capture`` values taken from the
    metadata table row.
    """

def parseClsMetadataRow(row):
    """Convert one metadata table row into a ClsMeta.

    The sample id in ``row.sid`` is split into its coded components; the
    descriptive tissue/stage/platform/capture values are copied from the row.
    The replicate numbers are taken from the sample id, not from the row.

    Raises Exception if the sample id does not have the expected layout.
    """
    pattern = r"^(SID)([HM])([A-Za-z]{2})([AEP])([OP])([PC])(\d{2})(\d{2})$"
    match = re.match(pattern, row.sid)
    if match is None:
        raise Exception(f"invalid sample id `{row.sid}'")
    # group 1 is the fixed SID prefix, which is not stored
    (_prefix, org_code, tissue_code, stage_code, platform_code, capture_code,
     bioRep, techRep) = match.groups()
    return ClsMeta(sid=row.sid,
                   org_code=org_code,
                   tissue_code=tissue_code,
                   stage_code=stage_code,
                   platform_code=platform_code,
                   capture_code=capture_code,
                   tissue=row.tissue,
                   stage=row.stage,
                   platform=row.platform,
                   capture=row.capture,
                   bioRep=bioRep,
                   techRep=techRep)

def loadClsMetadata(clsMetadataTab):
    """Read the sample metadata table and return a dict of ClsMeta keyed by sample id.

    The file is read with an explicit column list; an example row is:
        SIDHBrAOC0301	Brain	Adult	ont	post-capture	03	01
    """
    columns = ('sid', 'tissue', 'stage', 'platform', 'capture', 'bioRep', 'techRep')
    parsedRows = (parseClsMetadataRow(row)
                  for row in TsvReader(clsMetadataTab, columns=columns))
    return {clsMeta.sid: clsMeta for clsMeta in parsedRows}

##
# sample id to BAM name conversion
#
# mm10 pre-capture
#   ont-Crg-CapTrap_MpreCap_0+_Brain01Rep1.bam
#   pacBioSII-Cshl-CapTrap_MpreCap_0+_Brain01Rep1.bam
#   pacBioSII-Cshl-CapTrap_MpreCap_0+_EmbHeart01Rep1.bam
# mm10 post-capture
#   ont-Crg-CapTrap_Mv2_0+_Brain01Rep1.bam
#   pacBioSII-Cshl-CapTrap_Mv2_0+_Brain01Rep1.bam
#   pacBioSII-Cshl-CapTrap_Mv2_0+_EmbBrain01Rep1.bam
# hg38 pre-capture
#   ont-Crg-CapTrap_HpreCap_0+_Brain03Rep1.bam
#   pacBioSII-Cshl-CapTrap_HpreCap_0+_Brain03Rep1.bam
# hg38 post-capture
#   ont-Crg-CapTrap_Hv3_0+_Brain03Rep1.bam
#   pacBioSII-Cshl-CapTrap_Hv3_0+_Brain03Rep1.bam
#

# Several of these tables may need to be tried to get the BAM name,
# as the embryonic Wb code is overloaded (iPSC in human, EmbSC in mouse).
# stage code + tissue code -> name, entries that apply to human samples only
human_stage_tissue_code_to_name = {"ACp": "CpoolA", "PPl": "Placenta", "EWb": "iPSC"}

# stage code + tissue code -> name, entries that apply to mouse samples only
mouse_stage_tissue_code_to_name = {"EWb": "EmbSC"}

# stage code + tissue code -> name, shared by both organisms
both_stage_tissue_code_to_name = {
    # adult
    "ABr": "Brain", "AHe": "Heart", "ALi": "Liver", "ATe": "Testis",
    "ATp": "TpoolA", "AWb": "WBlood",
    # embryo
    "EBr": "EmbBrain", "EHe": "EmbHeart", "ELi": "EmbLiver",
}

# organism code -> version part of post-capture file names
org_code_to_ver = {
    "H": "Hv3_0",
    "M": "Mv2_0",
}

# platform code -> "<platform>-<lab>" prefix of file names
platform_code_to_platform_lab = {
    "O": "ont-Crg",
    "P": "pacBioSII-Cshl",
}

# capture code -> sub-directory holding that capture's files
capture_code_to_capture_dir = {
    "P": "pre-capture",
    "C": "post-capture",
}

def stageTissueMapName(clsMeta):
    """Return the file-name form of the sample's stage and tissue.

    The organism-specific table (human when org_code is 'H', otherwise mouse)
    is consulted first, then the table shared by both organisms.

    Raises Exception if no table has an entry for the stage+tissue code.
    """
    stageTissue = clsMeta.stage_code + clsMeta.tissue_code
    orgTable = (human_stage_tissue_code_to_name if clsMeta.org_code == 'H'
                else mouse_stage_tissue_code_to_name)
    for table in (orgTable, both_stage_tissue_code_to_name):
        name = table.get(stageTissue)
        if name is not None:
            return name
    raise Exception(f"BUG: can't get state tissue name for '{stageTissue}', id {clsMeta.sid}")

def sampleIdToFileName(clsMeta):
    """Return the base file name (no directory, no extension) for a sample.

    The name has the form
        <platform>-<lab>-CapTrap_<version>+_<stageTissue><bioRep>Rep<techRep>
    for example ``ont-Crg-CapTrap_Hv3_0+_Brain03Rep1``.  Pre-capture samples
    use ``<org>preCap_0`` as the version, post-capture ones use the entry in
    org_code_to_ver.

    Raises KeyError for an unknown platform or organism code, and whatever
    stageTissueMapName() raises for an unmapped stage/tissue.
    """
    fname = platform_code_to_platform_lab[clsMeta.platform_code] + "-CapTrap_"
    if clsMeta.capture_code == 'P':
        fname += clsMeta.org_code + "preCap_0"
    else:
        fname += org_code_to_ver[clsMeta.org_code]
    # Only the zero padding in front is dropped ("01" -> "1").  strip('0')
    # would also eat trailing zeros and turn replicate "10" into "1".
    techRep = clsMeta.techRep.lstrip('0')
    fname += "+_" + stageTissueMapName(clsMeta) + clsMeta.bioRep + "Rep" + techRep
    return fname

def sampleIdToRelPath(clsMeta, ext):
    """Return ``<capture-dir>/<sample file name><ext>``, relative to the track directory."""
    captureDir = capture_code_to_capture_dir[clsMeta.capture_code]
    fileName = sampleIdToFileName(clsMeta) + ext
    return osp.join(captureDir, fileName)

def sampleIdToBamPath(clsMeta):
    """Relative path of the sample's BAM file."""
    bamPath = sampleIdToRelPath(clsMeta, ".bam")
    return bamPath

def sampleIdToBedPath(clsMeta):
    """Relative path of the sample's compressed BED file."""
    bedPath = sampleIdToRelPath(clsMeta, ".bed.gz")
    return bedPath

def sampleIdToBigBedPath(clsMeta):
    """Relative path of the sample's bigBed file."""
    bigBedPath = sampleIdToRelPath(clsMeta, ".bb")
    return bigBedPath

def stageTissueName(clsMeta):
    """Return the tissue name, prefixed with 'Emb' for embryonic (stage code 'E') samples."""
    prefix = 'Emb' if clsMeta.stage_code == 'E' else ''
    return prefix + clsMeta.tissue if prefix else clsMeta.tissue

def stageTissueToBedPath(stageTissue):
    """Return the compressed BED file name for a stage/tissue: cls-models-<stageTissue>.bed.gz"""
    fileName = "cls-models-" + stageTissue + ".bed.gz"
    # single-component join; the name is returned as a relative path
    return osp.join(fileName)

def stageTissueToBigBedPath(stageTissue):
    """Return the bigBed file name for a stage/tissue: cls-models-<stageTissue>.bb"""
    fileName = "cls-models-" + stageTissue + ".bb"
    # single-component join; the name is returned as a relative path
    return osp.join(fileName)

# metadata platform value -> name used in long descriptions, track names and types
platform_desc_map = {"ont": "ONT", "pacBioSII": "PacBio"}

# metadata platform value -> abbreviated name used in short labels
platform_short_desc_map = {"ont": "ONT", "pacBioSII": "PB"}

def shortCapture(clsMeta):
    """Return the part of the capture description before the first '-' (e.g. 'post-capture' -> 'post')."""
    head, _sep, _rest = clsMeta.capture.partition('-')
    return head

def getStageDesc(clsMeta, short=False):
    """Return the stage text to show in a sample description, or None to omit it.

    No stage is shown for placenta samples or for pooled tissues.  In short
    form 'Adult' is also omitted and 'Embryo' becomes 'Emb'; in long form
    'Embryo' becomes 'Embryonic'.  Any other stage is returned as-is.
    """
    stage = clsMeta.stage
    if stage == "Placenta" or 'pool' in clsMeta.tissue:
        return None
    if stage == "Embryo":
        return "Emb" if short else "Embryonic"
    if short and stage == "Adult":
        return None
    return stage

def clsMetaToSampleDesc(clsMeta, short=False):
    """Return a human-readable '<stage> <tissue>' description of the sample.

    The stage is left out when getStageDesc() returns None.  Pool and blood
    tissue names are replaced with friendlier text; other tissues are used
    unchanged.
    """
    friendlyTissue = {
        "CpoolA": "Cell Line Pool",
        "TpoolA": "Tissue Pool",
        "WBlood": "Blood",
    }
    stage = getStageDesc(clsMeta, short)
    words = [] if stage is None else [stage]
    words.append(friendlyTissue.get(clsMeta.tissue, clsMeta.tissue))
    return " ".join(words)

def clsMetaToDesc(clsMeta, short=False):
    """Return '<sample description> <platform> <capture>' in long or short form.

    Raises KeyError if the platform is not in the description map being used.
    """
    platformNames = platform_short_desc_map if short else platform_desc_map
    platform = platformNames[clsMeta.platform]
    capture = shortCapture(clsMeta) if short else clsMeta.capture
    return " ".join((clsMetaToSampleDesc(clsMeta, short), platform, capture))

##
# GTF meta
##
def parseSidList(sids):
    """Split a comma-separated sample id string and return the ids as a sorted tuple."""
    idList = sids.split(',')
    idList.sort()
    return tuple(idList)

def loadTranscriptMeta(transcriptMetaTsv):
    """Read the per-transcript metadata TSV into a dict of rows keyed by transcriptId.

    Expected columns: transcriptId, samplesMetadata, artifact.  The
    samplesMetadata column is converted to a sorted tuple of sample ids.
    """
    rowsById = {}
    for row in TsvReader(transcriptMetaTsv, typeMap={"samplesMetadata": parseSidList}):
        rowsById[row.transcriptId] = row
    return rowsById

##
# BED generation
##
def allModelsFile(trackDir, ext):
    """Return the path of the combined all-models file with the given extension."""
    baseName = "cls-models" + ext
    return osp.join(trackDir, baseName)

def makeTranscriptBed(bed, clsMetaBySid, transcriptMeta, allBeds, bedsBySid,
                      bedsByStageTissue):
    """Annotate a transcript BED with its sample information and file it.

    Three extra columns are set on the BED: the number of samples, the sample
    descriptions and the sample ids.  The BED is then appended to allBeds, to
    the list of every sample it was seen in, and once to the list of every
    distinct stage/tissue among those samples.
    """
    sids = transcriptMeta.samplesMetadata
    descriptions = [clsMetaToDesc(clsMetaBySid[sid]) for sid in sids]
    bed.extraCols = [
        len(sids),
        strArrayJoin(descriptions),
        strArrayJoin(sids),
    ]

    allBeds.append(bed)

    for sid in sids:
        bedsBySid[sid].append(bed)

    # a set, so the BED is added only once per stage/tissue
    for stageTissue in {stageTissueName(clsMetaBySid[sid]) for sid in sids}:
        bedsByStageTissue[stageTissue].append(bed)

def processTranscript(bed, clsMetaBySid, transcriptMetas, allBeds,
                      bedsBySid, bedsByStageTissue):
    """Add a transcript BED to the output collections unless it is flagged as an artifact.

    Raises KeyError if the BED name has no entry in transcriptMetas.
    """
    meta = transcriptMetas[bed.name]
    if meta.artifact != 'no':
        return
    makeTranscriptBed(bed, clsMetaBySid, meta, allBeds, bedsBySid,
                      bedsByStageTissue)

def buildBeds(masterBedFile, clsMetaBySid, transcriptMetas):
    """Read the master BED file and group its non-artifact transcripts.

    Returns a tuple (allBeds, bedsBySid, bedsByStageTissue): the list of all
    kept BEDs, and defaultdicts of BED lists keyed by sample id and by
    stage/tissue name.
    """
    collected = ([], defaultdict(list), defaultdict(list))
    allBeds, bedsBySid, bedsByStageTissue = collected
    for transcriptBed in BedReader(masterBedFile):
        processTranscript(transcriptBed, clsMetaBySid, transcriptMetas,
                          allBeds, bedsBySid, bedsByStageTissue)
    return collected

def write_bed_file(beds, bedfile):
    """Sort the BEDs into genome order and write them to bedfile.

    The list is sorted in place.  The output directory is created if needed
    and the file is opened through fileOps.opengz.
    """
    fileOps.ensureFileDir(bedfile)
    beds.sort(key=Bed.genome_sort_key)
    with fileOps.opengz(bedfile, 'w') as outFh:
        for record in beds:
            record.write(outFh)

def write_bed_files(clsMetaBySid, allBeds, bedsBySid, bedsByStageTissue, trackDir):
    """Write the combined, per-sample and per-stage/tissue BED files under trackDir."""
    # all models in one file
    write_bed_file(allBeds, allModelsFile(trackDir, '.bed.gz'))

    # one file per sample, in its capture sub-directory
    for sid, sampleBeds in bedsBySid.items():
        samplePath = osp.join(trackDir, sampleIdToBedPath(clsMetaBySid[sid]))
        write_bed_file(sampleBeds, samplePath)

    # one file per stage/tissue
    for stageTissue, tissueBeds in bedsByStageTissue.items():
        tissuePath = osp.join(trackDir, stageTissueToBedPath(stageTissue))
        write_bed_file(tissueBeds, tissuePath)

def makeBeds(opts, clsMetadataTab, masterBedFile, transcriptMetaTsv, trackDir):
    """make-beds subcommand: load the metadata, build the BED groups and write them to trackDir.

    opts is accepted for a uniform subcommand signature and is not used.
    """
    sampleMetas = loadClsMetadata(clsMetadataTab)
    transcriptMetas = loadTranscriptMeta(transcriptMetaTsv)
    bedGroups = buildBeds(masterBedFile, sampleMetas, transcriptMetas)
    write_bed_files(sampleMetas, *bedGroups, trackDir)

##
# trackDb generation
##
# Top-level composite track stanza.  {parent} is filled with either an empty
# string or a complete "parent ..." line including its newline, which is why
# it is followed by a line continuation.  {sampleSubgroups} receives the
# per-sample members of subGroup2, already joined with " \" line continuations.
COMPOSITE_TDB = """\
track clsLongReadRnaTrack
{parent}\
compositeTrack on
shortLabel CLS long-read RNAs
longLabel Capture long-seq long-read lncRNAs
type bigBed 12
visibility pack
subGroup1 view Views \\
          targets_view=Targets \\
          models_view=Models \\
          sample_models_view=Sample_models \\
          per_expr_models_view=Per-experiment_models \\
          per_expr_reads_view=Per-experiment_reads
subGroup2 sample Sample \\
          combined=Combined \\
{sampleSubgroups}
subGroup3 type Type \\
          targets=Targets \\
          models=Models \\
          pre_capture_ont_models=Pre-capture_ONT_models \\
          pre_capture_pacbio_models=Pre-capture_PacBio_models \\
          post_capture_ont_models=Post-capture_ONT_models \\
          post_capture_pacbio_models=Post-capture_PacBio_models \\
          pre_capture_ont_reads=Pre-capture_ONT_reads \\
          pre_capture_pacbio_reads=Pre-capture_PacBio_reads \\
          post_capture_ont_reads=Post-capture_ONT_reads \\
          post_capture_pacbio_reads=Post-capture_PacBio_reads
dimensions dimX=type dimY=sample
html clsLongReadRna.html

"""

# Targets view together with its single target-regions subtrack.
# Placeholder: {bigDataUrl}.
TARGETS_TDB = """\
    track targets_view
    view targets_view
    parent clsLongReadRnaTrack
    shortLabel Targets
    visibility dense
    type bigBed
    noScoreFilter on

        track target_regions
        parent targets_view off
        subGroups view=targets_view sample=combined type=targets
        shortLabel CLS targets
        longLabel CLS target regions
        labelFields name
        defaultLabelFields none
        type bigBed 6 +
        visibility hide
        color 0,100,0
        bigDataUrl {bigDataUrl}
        priority 1

"""

# Models view together with its single combined transcript-models subtrack.
# Placeholder: {bigDataUrl}.
MODELS_TDB = """\
    track models_view
    view models_view
    parent clsLongReadRnaTrack
    shortLabel Models
    visibility pack
    type bigBed
    noScoreFilter on

        track cls_gene_models
        parent models_view on
        subGroups view=models_view sample=combined type=models
        shortLabel CLS transcripts
        longLabel CLS transcript models
        type bigBed 12 +
        bigDataUrl {bigDataUrl}
        color 12,12,120
        visibility pack
        priority 2

"""

# View stanza for the per-stage/tissue model tracks; no placeholders.
SAMPLE_MODELS_VIEW_TDB = """\
    track sample_models_view
    view sample_models_view
    parent clsLongReadRnaTrack
    shortLabel Sample models
    visibility squish
    type bigBed
    noScoreFilter on

"""

# One per-stage/tissue models subtrack.
# Placeholders: {sample}, {shortLabel}, {longLabel}, {bigDataUrl}.
SAMPLE_MODELS_TDB = """\
        track {sample}_models
        parent sample_models_view on
        shortLabel {shortLabel}
        longLabel {longLabel}
        bigDataUrl {bigDataUrl}
        type bigBed 12 +
        color 100,22,180
        visibility squish
        subGroups view=sample_models_view sample={sample} type=models

"""

# View stanza for the per-experiment read (BAM) tracks; no placeholders.
EXPR_READS_VIEW_TDB = """\
    track per_expr_reads_view
    view per_expr_reads_view
    parent clsLongReadRnaTrack on
    shortLabel Reads
    visibility squish
    type bam

"""

# One per-experiment reads subtrack.
# Placeholders: {track}, {shortLabel}, {longLabel}, {bigDataUrl}, {sample}, {type}.
EXPR_READS_TDB = """\
        track {track}
        parent per_expr_reads_view off
        shortLabel {shortLabel}
        longLabel {longLabel}
        bigDataUrl {bigDataUrl}
        type bam
        visibility hide
        subGroups view=per_expr_reads_view sample={sample} type={type}

"""

# View stanza for the per-experiment model tracks; no placeholders.
EXPR_MODELS_VIEW_TDB = """\
    track per_expr_models_view
    view per_expr_models_view
    parent clsLongReadRnaTrack on
    shortLabel Models
    visibility dense
    type bigBed
    noScoreFilter on

"""

# One per-experiment models subtrack.
# Placeholders: {track}, {shortLabel}, {longLabel}, {bigDataUrl}, {sample}, {type}.
EXPR_MODELS_TDB = """\
        track {track}
        parent per_expr_models_view off
        shortLabel {shortLabel}
        longLabel {longLabel}
        bigDataUrl {bigDataUrl}
        type bigBed 12 +
        noScoreFilter on
        color 100,22,180
        visibility hide
        subGroups view=per_expr_models_view sample={sample} type={type}

"""

def sampleGrpName(clsMeta):
    """Return the lower-case '<stage>_<tissue>' identifier used for the sample subgroup."""
    stage = clsMeta.stage.lower()
    tissue = clsMeta.tissue.lower()
    return "_".join((stage, tissue))

def buildSampleSubGroupDef(clsMeta):
    """Return one indented 'name=Label' subgroup member line for a sample.

    Spaces in the label are replaced with underscores.
    """
    label = clsMetaToSampleDesc(clsMeta).replace(' ', '_')
    return "          " + sampleGrpName(clsMeta) + "=" + label

def buildSampleSubGroupDefs(clsMetas):
    """Return the distinct sample subgroup member lines, in first-seen order.

    The lines are joined with ' \\' line continuations.
    """
    # dict keys keep insertion order and drop repeats
    uniqueDefs = dict.fromkeys(buildSampleSubGroupDef(clsMeta) for clsMeta in clsMetas)
    return " \\\n".join(uniqueDefs)


def getTrackName(clsMeta, dtype):
    """Return the lower-case track name '<stage>_<tissue>_<platform>_<capture>_<dtype>'.

    For example: adult_brain_ont_post_reads
    """
    parts = (clsMeta.stage,
             clsMeta.tissue,
             platform_desc_map[clsMeta.platform],
             shortCapture(clsMeta).lower(),
             dtype)
    return '_'.join(parts).lower()

def getShortLabel(clsMeta, dtype):
    """Return the short track label: abbreviated sample description followed by the data type."""
    shortDesc = clsMetaToDesc(clsMeta, short=True)
    return shortDesc + " " + dtype

def getLongLabel(clsMeta, dtype):
    """Return the long track label: full sample description followed by the data type.

    The 'models' data type is spelled out as 'transcript models'.
    """
    suffix = "transcript models" if dtype == "models" else dtype
    return clsMetaToDesc(clsMeta) + " " + suffix

def getType(clsMeta, dtype):
    """Return the type subgroup value '<capture>_<platform>_<dtype>', e.g. post_capture_ont_reads."""
    capture = clsMeta.capture.replace('-', '_').lower()
    platform = platform_desc_map[clsMeta.platform].lower()
    return capture + '_' + platform + '_' + dtype

def write_expr_data_trackdb(fh, clsMeta, dtype, bigDataUrl, template):
    """Fill in a per-experiment track template for one sample and write it to fh.

    dtype is the data type ('models' or 'reads' in this file).  The 'stage'
    value is supplied although the templates in this file do not reference it.
    """
    values = {
        "track": getTrackName(clsMeta, dtype),
        "sample": sampleGrpName(clsMeta),
        "stage": stageTissueName(clsMeta),
        "shortLabel": getShortLabel(clsMeta, dtype),
        "longLabel": getLongLabel(clsMeta, dtype),
        "type": getType(clsMeta, dtype),
        "bigDataUrl": bigDataUrl,
    }
    fh.write(template.format(**values))

def write_sample_model_trackdb(fh, clsMetaST, bigDataUrl):
    """Write the models track stanza for one stage/tissue.

    clsMetaST is a ClsMeta representing a unique stage/tissue.
    """
    sample = sampleGrpName(clsMetaST)
    shortLabel = clsMetaToSampleDesc(clsMetaST, short=True) + " models"
    longLabel = clsMetaToSampleDesc(clsMetaST, short=False) + " transcript models"
    stanza = SAMPLE_MODELS_TDB.format(sample=sample,
                                      shortLabel=shortLabel,
                                      longLabel=longLabel,
                                      bigDataUrl=bigDataUrl)
    fh.write(stanza)

def filter_stage_tissue_samples(clsMetas):
    """Return the first ClsMeta seen for each distinct stage/tissue, in input order.

    Two samples count as the same stage/tissue when both their
    stageTissueName() and sampleGrpName() values match.
    """
    firstByStageTissue = {}
    for clsMeta in clsMetas:
        key = (stageTissueName(clsMeta), sampleGrpName(clsMeta))
        # setdefault keeps the earliest entry for a key
        firstByStageTissue.setdefault(key, clsMeta)
    return list(firstByStageTissue.values())

def writeTrackDb(fh, clsMetas, trackDir, parent):
    """Write the complete trackDb to fh.

    Output order: composite track, targets view, combined models view,
    per-stage/tissue model tracks, per-experiment model tracks, and finally
    per-experiment read tracks.  All bigDataUrl values are paths under
    trackDir.  When parent is None no parent line is emitted.
    """
    parentSpec = f'parent {parent}\n' if parent is not None else ''
    sampleSubgroups = buildSampleSubGroupDefs(clsMetas)
    fh.write(COMPOSITE_TDB.format(sampleSubgroups=sampleSubgroups,
                                  parent=parentSpec))

    targetsUrl = osp.join(trackDir, "cls-targets.bb")
    fh.write(TARGETS_TDB.format(bigDataUrl=targetsUrl))

    modelsUrl = osp.join(trackDir, "cls-models.bb")
    fh.write(MODELS_TDB.format(bigDataUrl=modelsUrl))

    # one models track per distinct stage/tissue
    fh.write(SAMPLE_MODELS_VIEW_TDB)
    for stageTissueMeta in filter_stage_tissue_samples(clsMetas):
        bigBedName = stageTissueToBigBedPath(stageTissueName(stageTissueMeta))
        write_sample_model_trackdb(fh, stageTissueMeta,
                                   osp.join(trackDir, bigBedName))

    # per-experiment tracks: all models first, then all reads
    perExperiment = (
        (EXPR_MODELS_VIEW_TDB, 'models', sampleIdToBigBedPath, EXPR_MODELS_TDB),
        (EXPR_READS_VIEW_TDB, 'reads', sampleIdToBamPath, EXPR_READS_TDB),
    )
    for viewTdb, dtype, relPathOf, trackTdb in perExperiment:
        fh.write(viewTdb)
        for clsMeta in clsMetas:
            write_expr_data_trackdb(fh, clsMeta, dtype,
                                    osp.join(trackDir, relPathOf(clsMeta)),
                                    trackTdb)

def metaSortKey(clsMeta):
    """Sort key ordering samples by tissue, then stage, then sample id."""
    tissue = clsMeta.tissue
    stage = clsMeta.stage
    sid = clsMeta.sid
    return (tissue, stage, sid)

def makeTrackDb(opts, clsMetadataTab, trackDir, trackDb, parent):
    """make-trackdb subcommand: write the trackDb file for all samples in the metadata table.

    Samples are processed in metaSortKey order.  opts is accepted for a
    uniform subcommand signature and is not used.
    """
    metaBySid = loadClsMetadata(clsMetadataTab)
    orderedMetas = sorted(metaBySid.values(), key=metaSortKey)
    with open(trackDb, 'w') as trackDbFh:
        writeTrackDb(trackDbFh, orderedMetas, trackDir, parent)

##
# check for BAMs matching metadata
##
def listBams(trackDir, capture):
    """Return a frozenset of the *.bam paths in trackDir/<capture>.

    Raises Exception if the directory contains no BAM files.
    """
    pat = osp.join(trackDir, capture, '*.bam')
    found = frozenset(glob.glob(pat))
    if not found:
        raise Exception(f"no BAMs found matching {pat}")
    return found

def checkSampleWithBams(clsMeta, trackDir, bams, doneBams):
    """Check that a sample's BAM and its .bai index exist.

    The expected BAM path is looked up in bams (the set of BAM paths found
    under trackDir).  If present, it is recorded in doneBams and the index
    file is checked on disk.  Problems are reported on stderr.

    Returns True if both files were found, False otherwise.
    """
    # sampleIdToBamPath() takes only the sample metadata and returns a path
    # relative to the track directory; prefix trackDir so the result is
    # comparable with the globbed paths in bams.
    sampleBam = osp.join(trackDir, sampleIdToBamPath(clsMeta))
    if sampleBam not in bams:
        print(f"Error: {clsMeta.sid} BAM not found: {sampleBam}", file=sys.stderr)
        return False
    # recorded even when the index turns out to be missing, so the BAM is not
    # also reported as unreferenced
    doneBams.add(sampleBam)
    bai = sampleBam + '.bai'
    if not osp.exists(bai):
        print(f"Error: {clsMeta.sid} no BAI file found for: {sampleBam}", file=sys.stderr)
        return False
    return True

def checkSamplesWithBams(clsMetaBySid, trackDir, bams):
    """Check every sample in the metadata against the BAMs found.

    Returns a tuple (errCnt, doneBams): the number of samples that failed the
    check and the set of BAM paths that were matched to a sample.
    """
    doneBams = set()
    outcomes = [checkSampleWithBams(clsMeta, trackDir, bams, doneBams)
                for clsMeta in clsMetaBySid.values()]
    failures = sum(1 for ok in outcomes if not ok)
    return failures, doneBams

def checkReferencedBams(bams, doneBams):
    """Report BAMs that no metadata sample refers to.

    Each BAM in bams that is not in doneBams is reported on stderr.
    Returns the number of such BAMs.
    """
    orphans = [bam for bam in bams if bam not in doneBams]
    for bam in orphans:
        print(f"Error: BAM not referenced by metadata: {bam}", file=sys.stderr)
    return len(orphans)

def checkBams(opts, clsMetadataTab, trackDir):
    """check-bams subcommand: verify the metadata and the BAM files agree.

    Every sample must have a BAM with a .bai index under trackDir, and every
    BAM in the pre-capture and post-capture directories must belong to a
    sample.  A summary is printed to stderr; the process exits with status 1
    if any error was found.  opts is accepted for a uniform subcommand
    signature and is not used.
    """
    clsMetaBySid = loadClsMetadata(clsMetadataTab)
    bams = listBams(trackDir, "pre-capture") | listBams(trackDir, "post-capture")
    errCnt, doneBams = checkSamplesWithBams(clsMetaBySid, trackDir, bams)
    errCnt += checkReferencedBams(bams, doneBams)
    if errCnt > 0:
        print(f"{errCnt} errors detected", file=sys.stderr)
        # sys.exit rather than the bare exit(): the latter is injected by the
        # site module for interactive use and is not guaranteed to exist
        sys.exit(1)
    print(f"{len(doneBams)} BAMs validated", file=sys.stderr)

def main():
    """Parse the command line and run the selected subcommand."""
    opts, args = parseArgs()
    if args.command == 'make-beds':
        makeBeds(opts, args.clsMetadataTab, args.masterBedFile, args.transcriptMetaTsv, args.trackDir)
    elif args.command == 'make-trackdb':
        makeTrackDb(opts, args.clsMetadataTab, args.trackDir, args.trackDb, args.parent)
    elif args.command == 'check-bams':
        checkBams(opts, args.clsMetadataTab, args.trackDir)
    else:
        # raised explicitly so the check is not stripped when running with -O
        raise AssertionError(f"subcommand not handled: {args.command}")


# only run when executed as a script, so importing the module has no side effects
if __name__ == "__main__":
    main()
