#!/usr/bin/env python3

import subprocess
from collections import OrderedDict
from time import localtime, strftime
from datetime import datetime
from dateutil.relativedelta import relativedelta
import calendar
import math
import subprocess,sys,argparse,os

def parseArgs():
    """
    Parse the command line arguments.

    When the script is invoked with no arguments at all, the usage line and a
    longer description with example runs are printed and the process exits
    with status 0.

    Returns:
    - argparse.Namespace: dbs and workDir (str), cutOffThreshhold (float),
      numOfMonthsToCompare (int), singleReport (bool) and startDate/endDate
      (str, or None when not supplied).
    """
    parser = argparse.ArgumentParser(description = __doc__,
                                     formatter_class=argparse.RawDescriptionHelpFormatter)
    # Temporarily take out argparse's default optional group so the
    # 'required arguments' group is listed first in the help output.
    optional = parser._action_groups.pop()

    required = parser.add_argument_group('required arguments')

    required.add_argument ("dbs",
        help = "Database to query for track counts, e.g. hg19, hg38, mm10.")
    required.add_argument ("workDir",
        help = "Work directory to use for processing and final output. Use full path with '/' at the end.")
    # type=float / type=int: without them a value given on the command line
    # arrives as a string while the default is a number.
    # A literal '%' has to be written '%%' because argparse %-formats help text.
    optional.add_argument ("-c", "--cutOffThreshhold", dest = "cutOffThreshhold", default = .3, type = float,
        help = "Optional: The %% value, as compared to the trackCounts median, to be used " + \
            "as a threshhold to choose what tracks should be filtered. Default is .3.")
    optional.add_argument ("-n", "--numOfMonthsToCompare", dest = "numOfMonthsToCompare", default = 6, type = int,
        help = "Optional: The number of months to compare for filtering. " + \
            "Default is 6.")
    optional.add_argument ("-r", "--singleReport", default = False, action = "store_true",
        help = "Optional: Run as a singleReport. This generates track counts for the specified " + \
            "dates. This is useful for seeing track counts over a period of time. Requires the vars below")
    optional.add_argument ("-s", "--startDate", dest = "startDate",
        help = "Optional: The start date when running in singleReport mode. " + \
            "Date should be formatted as YYYY-MM-DD.")
    optional.add_argument ("-e", "--endDate", dest = "endDate",
        help = "Optional: The end date when running in singleReport mode. " + \
            "Date should be formatted as YYYY-MM-DD.")
    if (len(sys.argv) == 1):
        parser.print_usage()
        print("\nGenerates track counts based on the log parsing paring script '/hive/users/chmalee/logs/byDate/makeUsageReport'.\n" + \
              "The default behavior looks at track counts over the last 6 months and generates a list of tracks that\n" + \
              "were below 30% of the median track usage every month, allowing for a floor(n/4) exception where n is\n" + \
              "the number of months searched. Using the default 6 months, that allows for 1 exemption.\n" + \
              "This output can be used in order to identify seldomly used tracks for archiving, retiring or restructuring.\n\n" + \
              "Alternatively, the script can be run in 'singleReport(-r)' mode where it will generate a sorted list\n" + \
              "of track counts over a period of time. This could be useful for reporting purposes.\n\n" + \

              "Example runs:\n" + \
              "    trackCountsParse hg38 /hive/users/lrnassar/trackCounts/\n" + \
              "    trackCountsParse hg38 /hive/users/lrnassar/trackCounts/ -c .5 -n 12\n" + \
              "    trackCountsParse hg38 /hive/users/lrnassar/trackCounts/ -r -s 2023-01-01 -e 2023-12-31\n")

        sys.exit(0)
    parser._action_groups.append(optional)
    options = parser.parse_args()
    # singleReport mode builds file names and commands from both dates, so
    # reject the run up front instead of failing later on a None value.
    if options.singleReport and (options.startDate is None or options.endDate is None):
        parser.error("singleReport mode (-r) requires both --startDate (-s) and --endDate (-e).")
    return  options

def bash(cmd):
    """
    Run cmd through the shell (subprocess shell=True) and return its output.

    stderr is merged into stdout and the combined text is returned as a str.
    A non-zero exit status is converted into a RuntimeError that carries the
    command, its exit code and whatever it printed.
    """
    runOptions = dict(check=True, shell=True, universal_newlines=True,
                      stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    try:
        completed = subprocess.run(cmd, **runOptions)
    except subprocess.CalledProcessError as failure:
        raise RuntimeError("command '{}' return with error (code {}): {}".format(
            failure.cmd, failure.returncode, failure.output))
    return completed.stdout

def file_exists(filepath):
    """Return True when filepath names an existing regular file, else False."""
    isRegularFile = os.path.isfile(filepath)
    return isRegularFile

def generateTrackCounts(dbs,workDir,startDate,endDate):
    """
    Produce (or reuse) the month-binned track usage file for one database
    and date window, and return its path.

    Dates are strings shaped like 2023-11-01. The output path is
    workDir + dbs + "." + startDate + "to" + endDate + ".trackCounts.txt";
    when that file is already present nothing is re-run.

    The report command's output is filtered to drop custom track (ct_),
    hub (hub_), duplicate (dup_) and header (^#) lines, and the first
    column (the repeated database name) is cut away.
    """
    outputFileName = "".join([workDir, dbs, ".", startDate, "to", endDate, ".trackCounts.txt"])

    if file_exists(outputFileName):
        print("File already generated: " + outputFileName)
        return outputFileName

    print("Generating new file")
    cmd = "".join([
        "/hive/users/chmalee/logs/byDate/makeUsageReport -t -db ", dbs,
        " --bin-months -s ", startDate,
        " -e ", endDate,
        ' | grep -v "hub_\\|ct_\\|dup_\\|^#" | cut -f2- > ', outputFileName,
    ])
    # Echo the exact command before running it
    print(cmd)
    bash(cmd)
    return outputFileName

def createDicFromTrackCountsFile(trackCountsFilePath):
    """
    This function reads a file containing track counts and creates a dictionary
    with track names as keys and their corresponding counts as values.

    Each non-blank line is tab-separated with the track name in the first
    column and its count in the second. A trailing ':' on the track name is
    removed, and counts of lines that resolve to the same name are summed.
    Blank lines are skipped.

    Args:
    - trackCountsFilePath (str): Path to the file containing track counts.

    Returns:
    - tuple: (trackCountsDic, totalCount). trackCountsDic maps each track name
      to {"Count": <count as str>}; totalCount is the number of distinct names.
    """
    trackCountsDic = {}
    # 'with' guarantees the handle is closed even if a line fails to parse
    with open(trackCountsFilePath, 'r') as trackList:
        for line in trackList:
            strippedLine = line.rstrip()
            if strippedLine == "":
                continue
            parsedLine = strippedLine.split("\t")
            currentTrackName = parsedLine[0]
            if currentTrackName.endswith(":"):
                currentTrackName = currentTrackName[:-1]
            if currentTrackName not in trackCountsDic:
                trackCountsDic[currentTrackName] = {"Count": parsedLine[1]}
            else:
                # Counts are stored as strings, so sum as ints and convert back
                runningTotal = int(trackCountsDic[currentTrackName]["Count"]) + int(parsedLine[1])
                trackCountsDic[currentTrackName]["Count"] = str(runningTotal)
    totalCount = len(trackCountsDic)
    print("Total tracks to parse: "+str(totalCount))
    return(trackCountsDic,totalCount)

def checkIfThereIsAHigherLevelParentTrack(parentChildAssociationsDic,trackName,firstTry,dbs):
    """
    Perform a single tdbQuery lookup and record any parent it reveals for
    trackName in parentChildAssociationsDic. The function does not call
    itself; deeper levels are reached by the caller invoking it again.

    On the first try the record of trackName itself is queried. On later
    tries the record of the parent stored so far is queried, so a parent
    found there replaces the stored 'parentName' with the higher-level one.

    A 'parent' line, or a 'superTrack' line whose second word is not "on",
    marks the entry 'Container' = True and stores that second word as
    'parentName'. If the query output mentions none of compositeTrack,
    superTrack or parent, 'Container' is set to False on the first try only.

    Returns the same (mutated) parentChildAssociationsDic.
    """
    if firstTry == False:
        trackToQuery = parentChildAssociationsDic[trackName]['parentName']
    else:
        trackToQuery = trackName
    tdbQuery = bash("tdbQuery \"select * from "+dbs+" where track='"+trackToQuery+"'\"").split("\n")

    queryAsText = str(tdbQuery)
    mentionsHierarchy = any(keyword in queryAsText
                            for keyword in ("compositeTrack", "superTrack", "parent"))
    if not mentionsHierarchy:
        if firstTry != False:
            parentChildAssociationsDic[trackName]['Container'] = False
        return parentChildAssociationsDic

    for entry in tdbQuery:
        if entry.startswith("parent"):
            parentChildAssociationsDic[trackName]['Container'] = True
            parentChildAssociationsDic[trackName]['parentName'] = entry.split(" ")[1]
        elif entry.startswith("superTrack"):
            words = entry.split(" ")
            # "superTrack on" is a top-level declaration, not a parent reference
            if words[1] != "on":
                parentChildAssociationsDic[trackName]['Container'] = True
                parentChildAssociationsDic[trackName]['parentName'] = words[1]
    return parentChildAssociationsDic

def lookUpTracksToFindParentChildAssociations(trackCountsDic,totalCount,parentChildAssociationsDic,dbs):
    """
    For every track in trackCountsDic that has no entry yet in
    parentChildAssociationsDic, create the entry and resolve its parent by
    calling checkIfThereIsAHigherLevelParentTrack up to three times (the
    track itself, then at most two further levels up). Tracks that already
    have an entry are left untouched.

    Args:
    - trackCountsDic (dict): Track names (keys) to look up.
    - totalCount (int): Only used in the progress message printed every
      2000 newly processed tracks.
    - parentChildAssociationsDic (dict): Associations gathered so far;
      updated in place.
    - dbs (str): Database name passed on to the lookup.

    Returns:
    - dict: The updated parentChildAssociationsDic.
    """
    processed = 0
    for trackName in trackCountsDic:
        if trackName in parentChildAssociationsDic:
            continue
        parentChildAssociationsDic[trackName] = {}
        processed += 1
        if processed % 2000 == 0:
            print(str(processed)+" out of "+str(totalCount))
        # Level 1: does the track itself declare a parent?
        parentChildAssociationsDic = checkIfThereIsAHigherLevelParentTrack(parentChildAssociationsDic,trackName,True,dbs)
        if 'parentName' in parentChildAssociationsDic[trackName]:
            # Level 2: does that parent have a parent of its own?
            parentChildAssociationsDic = checkIfThereIsAHigherLevelParentTrack(parentChildAssociationsDic,trackName,False,dbs)
        else:
            # Top level tracks may be containers themselves but have no parent
            parentChildAssociationsDic[trackName]['Container'] = False
        if 'parentName' in parentChildAssociationsDic[trackName]:
            # Level 3: one last step up
            parentChildAssociationsDic = checkIfThereIsAHigherLevelParentTrack(parentChildAssociationsDic,trackName,False,dbs)

    return parentChildAssociationsDic

def buildFinalDicWithOnlyTopLevelTrackCounts(trackCountsDic,parentChildAssociationsDic):
    """
    Collapse per-track counts into one count per top-level track.

    A track whose association has 'Container' set to False is credited
    under its own name; any other track is credited to its 'parentName'.
    Each top-level name keeps the largest count (compared as ints) among
    everything credited to it; the stored value is the original 'Count'.
    """
    topLevelCounts = {}
    for trackName, trackDetails in trackCountsDic.items():
        association = parentChildAssociationsDic[trackName]
        if association['Container'] is False:
            creditedTo = trackName
        else:
            creditedTo = association['parentName']
        candidateCount = trackDetails['Count']
        if creditedTo not in topLevelCounts or int(candidateCount) > int(topLevelCounts[creditedTo]):
            topLevelCounts[creditedTo] = candidateCount
    return topLevelCounts

def makeFinalFileOnTopLevelTrackCounts(finalDicOfTopLevelTracksAndCounts,pathUrl,dbs):
    """
    This function creates a final output file containing details of top-level tracks,
    including their short labels (if available) and counts. The file is saved at the
    specified path URL, and a copy sorted by count (descending) is written to
    pathUrl + ".sorted".

    Each row is trackName, shortLabel, count separated by tabs. The shortLabel
    is taken from the first tdbQuery output line that starts with "shortLabel";
    when there is no such line the track name is used as its own label, so
    every non-empty track name produces exactly one row.

    Args:
    - finalDicOfTopLevelTracksAndCounts (dict): A dictionary containing top-level track names
                                                and their respective counts.
    - pathUrl (str): The path where the final output file will be saved.
    - dbs (str): Database name used in the tdbQuery lookups.
    """
    n=0
    # 'with' closes (and flushes) the file before it is sorted below, and
    # also when a lookup fails part-way through.
    with open(pathUrl, 'w') as outputFile:
        for key in finalDicOfTopLevelTracksAndCounts:
            if key == "":
                continue
            tdbQuery = bash("tdbQuery \"select * from "+dbs+" where track='"+key+"'\"").split("\n")
            # Fall back to the track name unless a real shortLabel line exists
            shortLabel = key
            for entry in tdbQuery:
                if entry.startswith("shortLabel"):
                    shortLabel = " ".join(entry.split(" ")[1:])
                    break
            n+=1
            outputFile.write(key+"\t"+shortLabel+"\t"+str(finalDicOfTopLevelTracksAndCounts[key])+"\n")
    print("Final file completed. Tota number of tracks: "+str(n))
    #Order and sort final file
    bash("sort -t $'\t' -k3 -rn "+pathUrl+" > "+pathUrl+".sorted")
    print("Final sorted file: "+pathUrl+".sorted")

def get_count(dicField):
    """Sort-key helper: return the 'trackCounts' value of a track record."""
    trackCounts = dicField['trackCounts']
    return trackCounts

def makeOrderedDicForSpecificTime(finalDicOfTopLevelTracksAndCounts,dbs):
    """
    Build a list of track records sorted by count, highest first.

    For every non-empty track name a tdbQuery lookup fetches the shortLabel
    (the first output line starting with "shortLabel"); when there is no such
    line the track name doubles as its label, so no track is dropped.

    Args:
    - finalDicOfTopLevelTracksAndCounts (dict): Top-level track names mapped
      to their counts (anything int() accepts).
    - dbs (str): Database name used in the tdbQuery lookups.

    Returns:
    - list: Despite the function name, a list of dicts with the keys
      'trackName', 'shortLabel' and 'trackCounts' (int), sorted descending
      by 'trackCounts'.
    """
    listOfCountsToSort = []
    #Fetch the shortLabels and make a list where each entry is a dic with trackName, shortLabel, and trackCount
    for key in finalDicOfTopLevelTracksAndCounts:
        if key == "":
            continue
        tdbQuery = bash("tdbQuery \"select * from "+dbs+" where track='"+key+"'\"").split("\n")
        # Fall back to the track name unless a real shortLabel line exists
        shortLabel = key
        for entry in tdbQuery:
            if entry.startswith("shortLabel"):
                shortLabel = " ".join(entry.split(" ")[1:])
                break
        listOfCountsToSort.append({'trackName':key,'shortLabel':shortLabel,'trackCounts':int(finalDicOfTopLevelTracksAndCounts[key])})
    print("Total number of tracks: "+str(len(listOfCountsToSort)))

    #Sort the trackCounts list to return in order to find data that meets cutoff threshhold
    listOfCountsToSort.sort(key=get_count, reverse=True)
    return(listOfCountsToSort)
    
def refineTrackCountsBasedOnCutOff(listOfTracks,cutOffThreshhold,period):
    """
    Take an ordered list containing dics of track counts and filter it based
    on a set threshhold. Then return a new ordered dictionary that contains
    the tracks below the threshhold with trackNames as keys and shortLabel
    and counts as values.

    The cutoff is the count of the middle element of listOfTracks (index
    len//2, used as the median because the list is expected to be sorted by
    count, highest first) multiplied by cutOffThreshhold.

    Args:
    - listOfTracks (list): Dicts with 'trackName', 'shortLabel', 'trackCounts'.
    - cutOffThreshhold (float or str): Fraction of the median; a string such
      as ".5" (e.g. straight from the command line) is converted to float.
    - period (str): Label used in the printed message.

    Returns:
    - OrderedDict: trackName -> {'shortLabel', 'trackCounts',
      'countComparedToMaxForPeriod', 'countComparedToCutoff'} for every track
      strictly below the cutoff. Empty when listOfTracks is empty.
    """
    finalTrackCountsDic = OrderedDict()
    if not listOfTracks:
        # No data for this period: there is no median to compare against
        print("No tracks found for "+period+", nothing to filter")
        return(finalTrackCountsDic)
    if isinstance(cutOffThreshhold, str):
        # int * str would repeat the string instead of scaling the median
        cutOffThreshhold = float(cutOffThreshhold)
    medianTrackCount = listOfTracks[len(listOfTracks)//2]['trackCounts']
    trackCountCutoff = medianTrackCount*cutOffThreshhold
    maxTrackCount = listOfTracks[0]["trackCounts"]
    for track in listOfTracks:
        if track['trackCounts'] < trackCountCutoff:
            finalTrackCountsDic[track['trackName']]={
                'shortLabel':track['shortLabel'],
                'trackCounts':track['trackCounts'],
                'countComparedToMaxForPeriod':track['trackCounts']/maxTrackCount,
                'countComparedToCutoff':track['trackCounts']/trackCountCutoff,
            }
    print("The trackCount cutoff for "+period+" is: "+str(trackCountCutoff))
    return(finalTrackCountsDic)

def getDateRangesForComparison(numOfMonthsToCompare, referenceDate=None):
    """
    Based on a number of months to compare given, find the year + month combination
    followed by the last day of each month to be used in the log query script. Return
    an ordered dictionary with the date ranges as keys, which will be used as the titles
    of the respective final ouputs, and the start/end dates as the content.
    **Note** This subtracts an additional month from the latest month in order
    to ensure that the logs chosen are complete.

    Args:
    - numOfMonthsToCompare (int or str): How many full months to go back;
      converted with int().
    - referenceDate (date or datetime, optional): The date treated as "today".
      Defaults to the current date; pass a fixed value for reproducible runs.

    Returns:
    - OrderedDict: Keys are "<startDate>-<endDate>" (dates as YYYY-MM-DD),
      newest month first; values are {'startDate': ..., 'endDate': ...}.
    """
    if referenceDate is None:
        referenceDate = datetime.today()
    # Put months on a single axis (year*12 + zero-based month) so stepping
    # back across a year boundary is plain integer arithmetic.
    currentMonthIndex = referenceDate.year*12 + referenceDate.month - 1
    dateRanges = OrderedDict()
    for monthsBack in range(1, int(numOfMonthsToCompare)+1):
        year, zeroBasedMonth = divmod(currentMonthIndex - monthsBack, 12)
        month = zeroBasedMonth + 1
        lastDateOfMonth = calendar.monthrange(year, month)[1]
        startDate = "%04d-%02d-01" % (year, month)
        endDate = "%04d-%02d-%02d" % (year, month, lastDateOfMonth)
        dateRanges[startDate+"-"+endDate] = {'startDate':startDate,'endDate':endDate}
    return(dateRanges)

def createFinalListOfTracksThatMeetCutoffEveryMonth(finalDicWithCutOffDics,numOfMonthsToCompare):
    """
    Keep only the tracks that show up in (nearly) every period's dictionary,
    which filters out one-month outliers. A track may be absent from at most
    floor(numOfMonthsToCompare/4) periods and still be kept. The names of
    tracks that are dropped are printed.

    Returns the list of kept track names, in first-seen order.
    """
    # Every track seen in any period starts out with zero missed periods
    missedPeriodsPerTrack = {}
    for tracksInPeriod in finalDicWithCutOffDics.values():
        for track in tracksInPeriod.keys():
            missedPeriodsPerTrack.setdefault(track, 0)

    # Tolerated absences: floor(n/4), e.g. 1 when comparing 6 months
    allowedMisses = math.floor(numOfMonthsToCompare/4)

    # One penalty point per period in which the track is absent
    for tracksInPeriod in finalDicWithCutOffDics.values():
        for track in missedPeriodsPerTrack:
            if track not in tracksInPeriod.keys():
                missedPeriodsPerTrack[track] += 1

    keptTracks = [track for track, missed in missedPeriodsPerTrack.items()
                  if missed <= allowedMisses]
    droppedTracks = [track for track, missed in missedPeriodsPerTrack.items()
                     if not missed <= allowedMisses]

    if droppedTracks != []:
        print("The following tracks were filtered out because they did not meet")
        print("the cutoff in all of the months specified:\n")
        for track in droppedTracks:
            print(track)
    return keptTracks

def get_avCount(dicField):
    """Sort-key helper: return the 'averageTrackCount' value of a track record."""
    averageTrackCount = dicField['averageTrackCount']
    return averageTrackCount

def constructSortedFinalTrackDicWithAllData(finalDicWithCutOffDics,listOfTracksInEligiblePeriods,numOfMonthsToCompare,dbs):
    """
    Takes in a final list of track names which have met all the conditions for potential archiving
    and constructs a final list with all of the data sorted. This includes averages over
    the time period for the track counts as well as the comparison to median/max. The track group
    is also queried for use in the ultimate decision.

    Averages are taken over the periods in which the track is actually
    present, so months covered by the floor(n/4) tolerance do not dilute them.
    A track that appears in no period at all is skipped.

    Args:
    - finalDicWithCutOffDics (dict): period -> {trackName -> per-period data}.
    - listOfTracksInEligiblePeriods (list): Track names to report.
    - numOfMonthsToCompare: Unused here; kept so existing callers keep working.
    - dbs (str): Database name used for the tdbQuery group lookup.

    Returns:
    - list: Dicts with 'trackName', 'shortLabel', 'group', 'averageTrackCount',
      'averageCountComparedToMaxForPeriod' and 'averageCountComparedToCutoff',
      sorted by 'averageTrackCount', highest first.
    """
    listOfTracksToReport = []
    for track in listOfTracksInEligiblePeriods:
        group = ""
        shortLabel = ""
        addedTrackCounts = 0
        addedCountComparedToMaxForPeriod = 0
        addedCountComparedToCutoff = 0
        monthsPresent = 0
        for period in finalDicWithCutOffDics:
            tracksInPeriod = finalDicWithCutOffDics[period]
            if track not in tracksInPeriod:
                continue
            monthsPresent+=1
            trackData = tracksInPeriod[track]
            if group == "":
                # First period containing the track: fetch label and group once
                shortLabel = trackData["shortLabel"]
                tdbQuery = bash("tdbQuery \"select group from "+dbs+" where track='"+track+"'\"").split("\n")
                if 'group' in str(tdbQuery):
                    group = tdbQuery[0].split(" ")[1]
                else:
                    group = "No group"
            addedTrackCounts+=trackData['trackCounts']
            addedCountComparedToMaxForPeriod+=trackData['countComparedToMaxForPeriod']
            addedCountComparedToCutoff+=trackData['countComparedToCutoff']

        if monthsPresent == 0:
            # Nothing to average; also avoids dividing by zero
            continue
        averageTrackCount = round(addedTrackCounts/monthsPresent,2)
        averageCountComparedToMaxForPeriod = round(addedCountComparedToMaxForPeriod/monthsPresent,4)
        averageCountComparedToCutoff = round(addedCountComparedToCutoff/monthsPresent,2)
        listOfTracksToReport.append({'trackName':track,'shortLabel':shortLabel,'group':group,'averageTrackCount':averageTrackCount,'averageCountComparedToMaxForPeriod':averageCountComparedToMaxForPeriod,'averageCountComparedToCutoff':averageCountComparedToCutoff})
    listOfTracksToReport.sort(key=get_avCount, reverse=True)
    return(listOfTracksToReport)

def writeFinalTrackListToFile(finalOutputTrackDicToReport,workDir,dbs,cutOffThreshhold,numOfMonthsToCompare):
    """
    Take the final processed list of track records and write it out to a tsv file including
    the vars used in the data generation.

    The file is created at
    workDir + "<YYYY-MM of numOfMonthsToCompare months ago>-<current YYYY-MM>." + dbs + ".tracksToArchive.tsv".
    It starts with two '#' header lines (the variables used, then the column
    names) followed by one tab-separated row per track record.

    Args:
    - finalOutputTrackDicToReport (list): Dicts with 'trackName', 'shortLabel',
      'group', 'averageTrackCount', 'averageCountComparedToMaxForPeriod' and
      'averageCountComparedToCutoff'.
    - workDir (str): Output directory, ending in '/'.
    - dbs (str): Database name.
    - cutOffThreshhold: Threshold value, recorded in the header.
    - numOfMonthsToCompare (int or str): Number of months, recorded in the
      header and used for the start month in the file name.
    """
    today = datetime.today()
    date = today.strftime('%Y-%m')
    # Step back on a year*12 + zero-based-month axis so that crossing a year
    # boundary is plain integer arithmetic.
    startYear, startZeroBasedMonth = divmod(today.year*12 + today.month - 1 - int(numOfMonthsToCompare), 12)
    monthToParse = "%04d-%02d" % (startYear, startZeroBasedMonth + 1)
    fileNamePathStartToEndDate = workDir+monthToParse+"-"+date+"."+dbs+".tracksToArchive.tsv"
    # 'with' closes the file even if a record is malformed
    with open(fileNamePathStartToEndDate,'w') as finalOutputFile:
        finalOutputFile.write("#Variables used in this file generation: dbs="+dbs+" numOfMonthsToCompare="+str(numOfMonthsToCompare)+" cutOffThreshhold="+str(cutOffThreshhold)+"\n")
        finalOutputFile.write("#trackName\tshortLabel\tgroup\taverageTrackCount\taverageCountComparedToMaxForPeriod\taverageCountComparedToCutoff\n")
        for track in finalOutputTrackDicToReport:
            finalOutputFile.write("\t".join([
                track['trackName'],
                track['shortLabel'],
                track['group'],
                str(track['averageTrackCount']),
                str(track['averageCountComparedToMaxForPeriod']),
                str(track['averageCountComparedToCutoff']),
            ])+"\n")
    print("\nCutoff tracks file complete: "+fileNamePathStartToEndDate)
    print("\nYou nicely format the output as such: tail -n +2 outputFilePath.tsv | tabFmt stdin")

def main():
    """
    Initialize options and call other functions.

    In singleReport mode, track counts for the given start/end dates are
    generated, rolled up to top-level tracks and written (plus a sorted copy)
    to workDir + "trackCounts.tsv". Otherwise each of the last
    numOfMonthsToCompare months is processed, the tracks below the cutoff are
    collected per month, and the tracks that stay below it across the months
    are written to the tracksToArchive tsv file.
    """
    options = parseArgs()
    dbs,workDir = options.dbs,options.workDir
    # Values typed on the command line may arrive as strings; normalise them
    # once here so the arithmetic further down works on numbers.
    cutOffThreshhold = float(options.cutOffThreshhold)
    numOfMonthsToCompare = int(options.numOfMonthsToCompare)
    singleReport = options.singleReport
    #Line below exists only for debugging purposes
    # dbs,workDir,cutOffThreshhold,numOfMonthsToCompare,singleReport = 'hg38','/hive/users/lrnassar/temp/tmp/',.3,6,False
    if singleReport:
        startDate,endDate,parentChildAssociationsDic = options.startDate,options.endDate,{}
        if startDate is None or endDate is None:
            sys.exit("singleReport mode (-r) requires both --startDate (-s) and --endDate (-e).")
        logFile = generateTrackCounts(dbs,workDir,startDate,endDate)
        trackCountsDic,totalCount = createDicFromTrackCountsFile(logFile)
        parentChildAssociationsDic = lookUpTracksToFindParentChildAssociations(trackCountsDic,totalCount,parentChildAssociationsDic,dbs)
        finalDicOfTopLevelTracksAndCounts = buildFinalDicWithOnlyTopLevelTrackCounts(trackCountsDic,parentChildAssociationsDic)
        makeFinalFileOnTopLevelTrackCounts(finalDicOfTopLevelTracksAndCounts,workDir+"trackCounts.tsv",dbs)

    else:
        print("Script started: "+strftime("%Y-%m-%d %H:%M:%S", localtime()))
        dateRanges = getDateRangesForComparison(numOfMonthsToCompare)
        # parentChildAssociationsDic is shared across months so each track's
        # parent is only looked up once.
        finalDicWithCutOffDics,parentChildAssociationsDic = OrderedDict(),{}
        for period in dateRanges:
            logFile = generateTrackCounts(dbs,workDir,dateRanges[period]['startDate'],dateRanges[period]['endDate'])
            trackCountsDic,totalCount = createDicFromTrackCountsFile(logFile)
            parentChildAssociationsDic = lookUpTracksToFindParentChildAssociations(trackCountsDic,totalCount,parentChildAssociationsDic,dbs)
            finalDicOfTopLevelTracksAndCounts = buildFinalDicWithOnlyTopLevelTrackCounts(trackCountsDic,parentChildAssociationsDic)
            listOfTracks = makeOrderedDicForSpecificTime(finalDicOfTopLevelTracksAndCounts,dbs)
            finalDicWithCutOffDics[period] = refineTrackCountsBasedOnCutOff(listOfTracks,cutOffThreshhold,period)
            print("The number of tracks that met the criteria for "+period+" is: "+str(len(finalDicWithCutOffDics[period].keys())))

        listOfTracksInEligiblePeriods = createFinalListOfTracksThatMeetCutoffEveryMonth(finalDicWithCutOffDics,numOfMonthsToCompare)
        finalOutputTrackDicToReport = constructSortedFinalTrackDicWithAllData(finalDicWithCutOffDics,listOfTracksInEligiblePeriods,numOfMonthsToCompare,dbs)
        writeFinalTrackListToFile(finalOutputTrackDicToReport,workDir,dbs,cutOffThreshhold,numOfMonthsToCompare)
        print("Script finished: "+strftime("%Y-%m-%d %H:%M:%S", localtime()))

# Run the report only when executed as a script, so that importing this
# module (e.g. to reuse or test a helper) does not start a run.
if __name__ == "__main__":
    main()
