#!/cluster/software/bin/python3

"""
Report database and track usage time, or plot the same.

Relies on the following files existing:
/hive/users/chmalee/logs/byDate/result/{hgw1-6,asiaNode,euroNode}/*.{dbUsage,trackUsage}.gz
"""

import sys, os, argparse, re, glob, gzip
import datetime as dt
import time
from collections import defaultdict,Counter

# Midnight at the start of the current day; used as the default --endDate.
today = dt.datetime.combine(dt.date.today(), dt.time())
# Midnight seven days before today; used as the default --startDate.
lastWeek = today - dt.timedelta(weeks=1)

# Directory holding one subdirectory of usage files per machine.
dataDir = "/hive/users/chmalee/logs/byDate/result/"
# Machine subdirectories searched under dataDir.
machList = ["hgw%d" % num for num in range(1, 7)] + ["asiaNode", "euroNode"]

# db -> date -> count, for example:
# {'hg19': {2020-01-02: 10, 2020-01-03: 11, ...}, 'hg38': {...}}
dbMatrix = defaultdict(Counter)
# db -> track -> date -> count; the innermost level is a Counter keyed by date.
trackMatrix = defaultdict(lambda: defaultdict(Counter))

def parseCommandLine():
    """Set up command line options, parse them and return the argparse namespace.

    --startDate and --endDate are always datetime objects in the returned
    namespace: either the module defaults (lastWeek/today) or the parsed
    YYYY-MM-DD value from the command line. A malformed date is reported as
    an argparse usage error. Exits with status 1 if neither -d nor -t was given.
    """
    def isoDate(value):
        """argparse type converter for YYYY-MM-DD strings."""
        try:
            return dt.datetime.strptime(value, "%Y-%m-%d")
        except ValueError:
            raise argparse.ArgumentTypeError(
                "invalid date '%s', expected format YYYY-MM-DD" % value)

    parser = argparse.ArgumentParser(
        description = "Reports database and/or track usage over time to stdout, and optionally charts the usage over time", add_help=True, prefix_chars = "-", usage = "%(prog)s [options]")

    # optional filters
    parser.add_argument("-t", "--trackCounts", action="store_true", default=False,
        help="Report/plot track usage info.")
    parser.add_argument("-d", "--dbCounts", action="store_true", default=False,
        help="Report/plot db usage info.")
    parser.add_argument("-db", "--database", action='store', default=None,
        help="Only plot usage for this database. Can be quoted comma separated list of databases, for example -db \"hg38,hg19,wuhCor1\". The keywords all,ucsc,hubs can be used to get counts for all databases, ucsc only, and hub only.")
    parser.add_argument("--track", action="store", default=None,
        help="Only plot usage for this one track. Can also be a quoted comma sep list of tracks, for example: --track \"refGene,knownGene\". You can do prefix or suffix matching with the '*' character, for example: --track \"clinGen*,*Gencode*\"")
    # argparse only runs `type` on string values, so the datetime defaults
    # below are passed through untouched.
    parser.add_argument("-s", "--startDate", action="store", type=isoDate, default=lastWeek,
        help="Starting date in format YYYY-MM-DD. Defaults to last week from today")
    parser.add_argument("-e", "--endDate", action="store", type=isoDate, default=today,
        help="Ending date in format YYYY-MM-DD. Defaults to today")
    parser.add_argument("-p", "--make-plot", action="store_true", default=False,
        help="Make usage graphs with --save-prefix filename prefix")
    parser.add_argument("--save-prefix", action="store", default=None,
        help="Save output .png with this filename prefix.")
    parser.add_argument("--bin-months", action="store_true", default=False,
        help="Bin data into month columns.")

    args = parser.parse_args()
    if not args.trackCounts and not args.dbCounts:
        sys.stderr.write("Error: please specify one of -d/--dbCounts or -t/--trackCounts.\n")
        parser.print_help()
        sys.exit(1)
    return args

def getDbOrTrackList(db=False):
    """Glob every machine directory for usage files and return the paths.

    With db=True the *.dbUsage.* files are collected, otherwise the
    *.trackUsage.* files.
    """
    pattern = "/*.dbUsage.*" if db else "/*.trackUsage.*"
    found = []
    for mach in machList:
        found.extend(glob.glob(dataDir + mach + pattern))
    return found

def isFileInRange(logDateStr, startDate, endDate):
    """Return True if a log file dated logDateStr (YYYYMMDD) is worth reading.

    A file qualifies when its date falls inside [startDate, endDate], or when
    it is dated before startDate but startDate is at most seven days later.
    """
    fileStart = dt.datetime.strptime(logDateStr, "%Y%m%d")
    if fileStart >= startDate:
        return fileStart <= endDate
    # File begins before the requested range; keep it if the following
    # seven days reach startDate.
    return fileStart + dt.timedelta(days=7) >= startDate

def formValidDbs(dbArg):
    """Turn a comma separated db string from the command line into a list.

    Returns None when nothing was supplied or when one of the keywords
    all/ucsc/hubs was used instead of explicit database names.
    """
    if not dbArg or dbArg in ("all", "ucsc", "hubs"):
        return None
    return [name.strip() for name in dbArg.split(',')]

def formValidTracks(tracks):
    """Turn a comma separated track string into a set of stripped names.

    Returns None when no track filter was supplied.
    """
    if tracks:
        return {name.strip() for name in tracks.split(',')}
    return None

def parseDbFile(infh, args):
    """Parse a trimmed.dbUsage.gz file into dbMatrix.

    The '#' header line lists one date per count column (after the first
    field); every other line is a db name followed by tab separated counts.
    Only columns whose header date lies within [args.startDate, args.endDate]
    are added to dbMatrix, and only for databases passing the -db filter.

    Returns the list of in-range dates found in the header(s).
    Exits with an error message if the file cannot be read as gzip.
    """
    dateList = []
    # (column index, date) for each in-range header date. The index is kept so
    # a count is always paired with its own date even when earlier columns
    # fall outside the requested range.
    columns = []
    dbList = formValidDbs(args.database)
    try:
        for line in infh:
            if isinstance(line, bytes):
                line = line.decode("ASCII")
            fields = line.strip().split('\t')
            if line.startswith('#'):
                columns = []
                for idx, d in enumerate(fields[1:]):
                    newDate = dt.datetime.strptime(d, "%Y-%m-%d")
                    if args.startDate <= newDate <= args.endDate:
                        columns.append((idx, newDate))
                        dateList.append(newDate)
            else:
                db = fields[0]
                if not dbList or db in dbList:
                    counts = fields[1:]
                    c = Counter({day: int(counts[idx]) for idx, day in columns if idx < len(counts)})
                    dbMatrix[db].update(c)
    except OSError as e:
        sys.exit("Error: %s not a valid gzipped file, original message:\n%s" % (infh.name, e))
    return dateList

def isTrackNameMatch(trackName, cmdLineTrackList):
    """Return True if trackName is selected by the --track filter.

    A name matches when it is listed exactly, or when it matches one of the
    entries containing '*', where '*' stands for any run of characters
    (including none). Every other character in an entry is taken literally,
    and the whole track name must match the entry.
    """
    if trackName in cmdLineTrackList:
        return True
    for cmdTrack in cmdLineTrackList:
        if "*" not in cmdTrack:
            continue
        # Translate the wildcard entry into a regex: escape the literal
        # pieces and join them with '.*' so that only '*' is special.
        pattern = ".*".join(re.escape(piece) for piece in cmdTrack.split("*"))
        if re.fullmatch(pattern, trackName):
            return True
    return False

def parseTrackFile(infh, args):
    """Parse a trimmed.trackUsage.gz file into trackMatrix.

    The '#' header line lists one date per count column (after the first two
    fields); every other line is a db name, a track name and tab separated
    counts. Only columns whose header date lies within
    [args.startDate, args.endDate] are added to trackMatrix, and only for
    databases/tracks passing the -db and --track filters.

    Returns the list of in-range dates found in the header(s).
    Exits with an error message on a data line lacking a db and track field,
    or if the file cannot be read as gzip.
    """
    dateList = []
    # (column index, date) for each in-range header date. The index is kept so
    # a count is always paired with its own date even when earlier columns
    # fall outside the requested range.
    columns = []
    dbList = formValidDbs(args.database)
    trackList = formValidTracks(args.track)
    try:
        for lineNumber, line in enumerate(infh, start=1):
            if isinstance(line, bytes):
                line = line.decode("ASCII")
            dbTrackCounts = line.strip().split('\t')
            if line.startswith('#'):
                columns = []
                for idx, d in enumerate(dbTrackCounts[2:]):
                    newDate = dt.datetime.strptime(d, "%Y-%m-%d")
                    if args.startDate <= newDate <= args.endDate:
                        columns.append((idx, newDate))
                        dateList.append(newDate)
                continue
            if len(dbTrackCounts) < 2:
                sys.exit("Error: dbTrackCounts: '%s', infile: '%s', line number: '%s'" % (dbTrackCounts, infh.name, lineNumber))
            db = dbTrackCounts[0]
            track = dbTrackCounts[1]
            if dbList and db not in dbList:
                continue
            if trackList and not isTrackNameMatch(track, trackList):
                continue
            counts = dbTrackCounts[2:]
            c = Counter({day: int(counts[idx]) for idx, day in columns if idx < len(counts)})
            trackMatrix[db][track].update(c)
    except OSError as e:
        sys.exit("Error: %s not a valid gzipped file, original message:\n%s" % (infh.name, e))
    return dateList

def parseFileList(args, fileList):
    """Keep only the files whose date prefix makes them relevant to the
    requested start/end dates.

    The date is taken from the file's base name, up to the first '.'.
    """
    wanted = []
    for path in fileList:
        baseName = path.rsplit('/', 1)[-1]
        datePrefix = baseName.partition('.')[0]
        if isFileInRange(datePrefix, args.startDate, args.endDate):
            wanted.append(path)
    return wanted

def getCounts(args):
    """Read the input files selected by the command line and fill the
    module-level matrices.

    Returns the sorted list of distinct dates seen across all parsed files.
    """
    seenDates = set()
    # (requested?, value for getDbOrTrackList's db flag, parser); db files
    # are handled before track files.
    jobs = (
        (args.dbCounts, True, parseDbFile),
        (args.trackCounts, False, parseTrackFile),
    )
    for requested, wantDb, parseFunc in jobs:
        if not requested:
            continue
        candidates = getDbOrTrackList(db=wantDb)
        for path in parseFileList(args, candidates):
            with gzip.open(path, "rb") as handle:
                seenDates.update(parseFunc(handle, args))
    return sorted(seenDates)

def maybeBinDbData(args, dbData, datesFound):
    """Sum daily counts into wider bins when a long date range was requested.

    The bin width depends on the number of days from args.startDate to
    args.endDate inclusive:
       up to 2 weeks:           no binning, dbData is returned as is
       over 2, up to 8 weeks:   2 day bins
       over 8, up to 16 weeks:  3 day bins
       over 16, up to 26 weeks: 5 day bins
       over 26 weeks:           7 day bins

    Returns (list of dates, data): every day of the range with the original
    dbData when not binning, otherwise the first day of each bin with a
    {db: {binStart: summed count}} mapping.
    """
    numDaysRequested = (args.endDate - args.startDate).days + 1
    if numDaysRequested <= 7 * 2:
        everyDay = [args.startDate + dt.timedelta(days=offset)
                    for offset in range(numDaysRequested)]
        return everyDay, dbData

    # First threshold (in weeks) that covers the request decides the width;
    # anything longer than the last threshold gets weekly bins.
    binSize = 7
    for maxWeeks, width in ((8, 2), (16, 3), (26, 5)):
        if numDaysRequested <= 7 * maxWeeks:
            binSize = width
            break

    binStarts = [args.startDate + dt.timedelta(days=offset)
                 for offset in range(0, numDaysRequested, binSize)]
    lastDayOffset = dt.timedelta(days=binSize - 1)
    binned = defaultdict(dict)
    for db in dbData:
        perDay = dbData[db]
        for binStart in binStarts:
            binEnd = binStart + lastDayOffset
            binned[db][binStart] = sum(
                hits for day, hits in perDay.items() if binStart <= day <= binEnd)

    return binStarts, binned

def makeDbList(args):
    """Return the databases to report track usage for, based on -db.

    No -db value or "all" selects every db in trackMatrix, "ucsc" the ones
    not starting with "hub_", "hubs" the ones starting with "hub_"; anything
    else is treated as an explicit comma separated list.
    """
    selection = args.database
    if not selection or selection == "all":
        return trackMatrix.keys()
    if selection == "ucsc":
        return [name for name in trackMatrix.keys() if not name.startswith("hub_")]
    if selection == "hubs":
        return [name for name in trackMatrix.keys() if name.startswith("hub_")]
    return formValidDbs(selection)
    

def makeTrackData(args, datesFound):
    """Print the track usage report to stdout.

    One header line, then one tab separated line per db/track holding the
    count for each date column. With --bin-months the columns are YYYY-MM
    months and the counts are monthly sums, otherwise they are single days.
    """
    monthly = args.bin_months
    labelFormat = "%Y-%m" if monthly else "%Y-%m-%d"
    columnLabels = sorted({found.strftime(labelFormat) for found in datesFound})
    print("#Date\ttrack\t" + "\t".join(columnLabels))

    if not monthly:
        # Day labels converted back to the datetime keys used in trackMatrix.
        columnDays = [dt.datetime.strptime(label, "%Y-%m-%d") for label in columnLabels]

    for db in makeDbList(args):
        for track in trackMatrix[db]:
            perDay = trackMatrix[db][track]
            if monthly:
                perMonth = Counter()
                for day in perDay:
                    perMonth[day.strftime("%Y-%m")] += perDay[day]
                cells = [str(perMonth[label]) for label in columnLabels]
                print("\t".join([db, track] + cells))
            else:
                cells = [str(perDay[day]) for day in columnDays]
                # The separator after the track name is always written,
                # even when there are no date columns.
                print(db + "\t" + track + "\t" + "\t".join(cells))

def makeDbData(args, datesFound):
    """Print the db usage report to stdout and return the reported data.

    Databases are chosen from -db (all, ucsc, hubs or an explicit list) and
    ordered by total count, highest first. When -db was not given only the
    top 15 are reported. Returns {db: dbMatrix[db]} for the reported dbs.
    """
    selection = args.database
    if not selection or selection == "all":
        candidates = dbMatrix.keys()
    elif selection == "ucsc":
        candidates = [name for name in dbMatrix.keys() if not name.startswith("hub_")]
    elif selection == "hubs":
        candidates = [name for name in dbMatrix.keys() if name.startswith("hub_")]
    else:
        candidates = formValidDbs(selection)

    def totalHits(name):
        return sum(dbMatrix[name].values())

    ranked = sorted(candidates, key=totalHits, reverse=True)
    # Without any -db argument the report is limited to the 15 busiest dbs.
    dbList = ranked if selection else ranked[:15]

    # header, one row per day, then a totals row
    print("#Date\t" + "\t".join(dbList))
    for day in datesFound:
        cells = [str(dbMatrix[name][day]) for name in dbList]
        print(day.strftime("%a %b %d, %Y") + "\t" + "\t".join(cells))
    totalsLabel = "#totals" if len(dbList) > 1 else "#total"
    print(totalsLabel + "".join("\t%s" % totalHits(name) for name in dbList))

    # handed back so the same data can be plotted
    return {name: dbMatrix[name] for name in dbList}

def getReportData(args, datesFound):
    """Print the requested text report(s) and return the data to graph.

    Writes a warning to stderr when the requested range includes
    Jan 1 2020, where data is missing. Returns (dbData, trackData): dbData is
    the mapping returned by makeDbData when -d was given, otherwise None;
    trackData is always None because the track report is only printed.
    """
    missingDataStart = dt.datetime(year=2020, month=1, day=1)
    # Compare the bounds directly rather than indexing a generated list of
    # days, which is empty (and raised IndexError) when endDate < startDate.
    if args.startDate <= missingDataStart <= args.endDate:
        sys.stderr.write("#warning: missing data from Jan 1 - Jan 4 2020\n")

    # the below routines print their reports to stdout
    dbData = None
    trackData = None
    if args.dbCounts:
        dbData = makeDbData(args, datesFound)
    if args.trackCounts:
        makeTrackData(args, datesFound)
    return dbData, trackData

def graphCounts(args, dbData, trackData, datesFound):
    """Use matplotlib to graph database usage and save it as a .png.

    dbData is the {db: {date: count}} mapping returned by makeDbData; it is
    binned with maybeBinDbData before plotting, one line per database.
    trackData is accepted but not plotted. Nothing is drawn (a warning is
    written to stderr instead) when there is no db data or the requested
    date range is empty.
    """
    if dbData is None:
        # Happens when a plot is requested without -d/--dbCounts.
        sys.stderr.write("Warning: no database usage data to plot, use -d/--dbCounts with -p/--make-plot.\n")
        return

    binnedDateRange, binnedDbData = maybeBinDbData(args, dbData, datesFound)
    if not binnedDateRange:
        sys.stderr.write("Warning: empty date range requested, nothing to plot.\n")
        return

    import numpy as np
    import pandas as pd # for handling date ranges
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    import matplotlib.dates as mdates
    # NOTE(review): 'style' is presumably a style sheet resolvable from where
    # the script runs -- confirm.
    plt.style.use('style')
    fig,ax = plt.subplots()

    startDate = binnedDateRange[0]
    endDate = binnedDateRange[-1]
    # A single-day range has no second entry to measure the bin width from.
    if len(binnedDateRange) > 1:
        binAmount = (binnedDateRange[1] - startDate).days
    else:
        binAmount = 1
    ax.set_xlim(startDate, endDate)
    for db in binnedDbData:
        x = sorted(binnedDbData[db])
        y = [binnedDbData[db][day] for day in x]
        xvals = np.array(x)
        yvals = np.array(y)
        ax.plot(xvals, yvals, label="%s total hits: %d" % (db,sum(y)))

    if binAmount < 2:
        ax.xaxis.set_minor_locator(mdates.DayLocator()) # tick every day
    else:
        pdrange = pd.date_range(startDate, endDate)
        xticks = [dt.datetime(x.year, x.month, x.day) for x in  pd.to_datetime(np.linspace(pdrange[0].value, pdrange[-1].value, num = 8))]
        ax.set_xticks(xticks) # label every binned day but only label 8 of them, evenly spaced
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m-%d"))
    if binAmount > 1:
        ax.set_xlabel("Day" + " (binned every %d days)" % (binAmount))
    else:
        ax.set_xlabel("Day")
    ax.set_ylabel("Usage")

    dateSpan = (startDate.strftime("%b %d, %Y"), endDate.strftime("%b %d, %Y"))
    # formValidDbs returns None for the all/ucsc/hubs keywords; those get the
    # generic title instead of failing on len(None).
    dbList = formValidDbs(args.database)
    if not args.database:
        ax.set_title("Top 15 Databases used between %s and %s" % dateSpan)
    elif dbList and len(dbList) == 1:
        ax.set_title("%s usage between %s and %s" % ((dbList[0],) + dateSpan))
    else:
        ax.set_title("Database usage between %s and %s" % dateSpan)

    ax.legend()
    saveName="/cluster/home/chmalee/public_html/logs/dbUsage/"
    # A prefix starting with '.' is ignored and the default name is used.
    if args.save_prefix and args.save_prefix[0] != ".":
        saveName += args.save_prefix + ".dbUsage.png"
    else:
        saveName += "dbUsage.png"
    plt.savefig(saveName, dpi=600)

def makeReport(args, datesFound):
    """Print the text report(s), then draw the graph when -p/--make-plot was given."""
    reportData = getReportData(args, datesFound)
    if not args.make_plot:
        return
    dbData, trackData = reportData
    graphCounts(args, dbData, trackData, datesFound)

def main():
    """Parse the command line, load the usage counts and produce the report."""
    options = parseCommandLine()
    makeReport(options, getCounts(options))

# Run the report only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
