#!/usr/bin/env python3
# A beacon allows very limited queries against a set of variants without allowing someone
# to download the list of variants
# see ga4gh.org/#/beacon (UCSC redmine 14393)

import cgi, subprocess, sys, cgitb, os, socket, time, json, glob, urllib.parse, time
import sqlite3, gzip, optparse, gc, string, re, socket
from os.path import join, isfile, dirname, isdir, basename, exists

# Identity of this beacon, sent back to clients as part of the responses.
hostName = os.environ.get("HTTP_HOST", "localhost")
homepage = "http://%s/cgi-bin/hgBeacon" % hostName
beaconDesc = {
    "id": "ucsc-browser",
    "name": "Genome Browser",
    "organization": "UCSC",
    "description": "UCSC Genome Browser",
    "api": "0.2",
    "homepage": homepage,
}

# The list of genome references that this beacon supports.
beaconReferences = ["GRCh37"]

# File name of the sqlite database.
SQLITEDB = "beaconData.sqlite"

# dir of the CGIs relative to REQUEST_URI
#cgiDir = ".." # symlinks are one level below the main cgi-bin dir
cgiDir = "."  # symlinks are in the current dir, for GA4GH

# Cache of the hg.conf dict; None as long as hg.conf has not been parsed.
hgConf = None

# Descriptions of the datasets that this beacon is serving, keyed by dataset ID.
DataSetDescs = {
    "hgmd": "Human Genome Variation Database, only single-nucleotide variants, public version, provided by Biobase",
    "lovd": "Leiden Open Varation Database installations that agreed to share their variants, only single-nucleotide variants and deletions",
    "test": "small test data on Chromosome 1, from ICGC",
    "test2": "another piece of test data from Chromosome 1, also from ICGC",
}

# Special case: some datasets do not have alt alleles. For these, an overlap
# with the queried position is enough to trigger a "true".
NoAltDataSets = ["hgmd"]

def queryBottleneck(host, port, ip):
    """Contact a UCSC-style bottleneck server and return the current delay.

    The request is one length byte followed by the client address. The reply
    is one length byte followed by that many bytes, which are parsed as a
    decimal integer.

    Args:
        host: host name of the bottleneck server.
        port: TCP port of the bottleneck server (anything int() accepts).
        ip: client address as a string; its latin1 encoding must not be
            longer than 255 bytes because the length is sent as one byte.

    Returns:
        The delay reported by the server, as an int.

    Raises:
        ConnectionError: if the server closes the connection before the
            announced number of bytes has arrived.
        ValueError: if the reply cannot be parsed as an integer.
    """
    msg = ip.encode("latin1")

    # the context manager closes the socket on every exit path
    with socket.socket() as s:
        s.connect((host, int(port)))
        # sendall() keeps writing until the whole request is out;
        # a plain send() is allowed to stop after a partial write
        s.sendall(bytes([len(msg)]) + msg)

        # read delay time: first the length byte...
        header = s.recv(1)
        if not header:
            raise ConnectionError("bottleneck server closed the connection without a reply")
        expLen = header[0]

        # ...then exactly expLen bytes. recv() returns b"" once the peer has
        # closed the connection, which must end the loop instead of spinning.
        buf = []
        totalLen = 0
        while totalLen < expLen:
            resp = s.recv(expLen - totalLen)
            if not resp:
                raise ConnectionError(
                    "bottleneck server closed the connection after %d of %d bytes"
                    % (totalLen, expLen))
            buf.append(resp)
            totalLen += len(resp)

    return int(b"".join(buf))

def parseConf(fname):
    """Parse a hg.conf style file and return it as a dict key -> value.

    Lines starting with "#" are comments. A line "include <file>" pulls in
    another config file, resolved relative to the directory of fname; include
    files that do not exist are silently ignored. Every other line that
    contains "=" is split at the first "=" into key and value.

    Args:
        fname: path of the config file.

    Returns:
        dict of str -> str. Later assignments override earlier ones, and
        whitespace around keys and values is removed.
    """
    conf = {}
    # 'with' makes sure the file is closed, also when a line fails to parse
    with open(fname) as ifh:
        for line in ifh:
            line = line.strip()
            if line.startswith("#"):
                continue
            if line.startswith("include "):
                inclPath = join(dirname(fname), line.split()[1])
                if isfile(inclPath):
                    conf.update(parseConf(inclPath))
            elif "=" in line:
                key, value = line.split("=", 1)
                # strip both parts so that "key = value" is found under
                # "key", not under "key " with a trailing blank
                conf[key.strip()] = value.strip()
    return conf

def parseHgConf(confDir="."):
    """Return the contents of hg.conf as a dict key -> value.

    The result is kept in the module-level hgConf, so the file is parsed at
    most once per process. hg.conf is searched first in confDir (relative to
    the directory of this script), then in the script directory itself. If
    neither file exists, an empty dict is returned.
    """
    global hgConf
    if hgConf is not None:
        return hgConf

    # from now on the cache counts as filled, even if no file is found
    hgConf = {}

    scriptDir = dirname(__file__)
    candidates = (
        join(scriptDir, confDir, "hg.conf"),
        join(scriptDir, "hg.conf"),
    )
    for confPath in candidates:
        if isfile(confPath):
            hgConf = parseConf(confPath)
            return hgConf
    return {}

def errAbort(errMsg=None):
    """Exit the program, after printing errMsg as a JSON response if one is given.

    The JSON object also carries a pointer to the help page of this beacon.
    The exit status is 0 in both cases. This function never returns.
    """
    if errMsg is not None:
        helpUrl = "http://%s/cgi-bin/hgBeacon?page=help" % hostName
        moreInfo = "for a complete description of the parameters, read the help message at %s" % helpUrl
        printJson({"errormsg": errMsg, "more_info": moreInfo})
    sys.exit(0)

def printHelp():
    """Print the help text as a HTML page to stdout and exit.

    The text is read from help.txt, which is looked up in /gbdb/hg19/beacon
    when the host name ends with ".ucsc.edu" and next to this script
    otherwise. The file is used as a %-format template that is filled in
    from the local variables of this function.

    If help.txt does not exist, errAbort() reports that as a JSON error. The
    check runs before any output is written, so the JSON response is not
    preceded by an already-sent HTML header.
    """
    # NOTE: the names of the locals below are part of the template interface
    # (see "helpText % locals()"), so do not rename them. Which of them
    # help.txt actually references is not visible from here.
    host = hostName # convert from global to local var
    if host.endswith(".ucsc.edu"):
        helpDir = "/gbdb/hg19/beacon"
    else:
        helpDir = dirname(__file__)

    helpPath = join(helpDir, "help.txt")
    if not exists(helpPath):
        errAbort("no file %s found. The beacon is not activated on this machine" % helpPath)

    with open(helpPath) as helpFh:
        helpText = helpFh.read()

    print("Content-Type: text/html\n")
    print("<html><body>")
    print(helpText % locals())
    print("</body></html>")
    sys.exit(0)

def dataSetResources():
    """Return the total number of variants and the list of DataSetResources.

    Every table of the sqlite database is one dataset.

    Returns:
        A tuple (totalSize, dsrList). dsrList has one tuple
        (datasetId, description, itemCount) per table; the description is ""
        for datasets without an entry in DataSetDescs. totalSize is the sum
        of all item counts.
    """
    totalSize = 0
    dsrList = []
    conn = dbOpen(mustExist=True)
    try:
        for tableName in dbListTables(conn):
            # table names come from the database itself, not from the request
            rows = dbQuery(conn, "SELECT COUNT(*) from %s" % tableName, None)
            itemCount = rows[0][0]

            # the dataset ID is simply the name of the table
            dsrList.append((tableName, DataSetDescs.get(tableName, ""), itemCount))
            totalSize += itemCount
    finally:
        # release the database handle also when a query fails
        conn.close()
    return totalSize, dsrList

def beaconInfo():
    """Return the beaconInfo dict.

    It describes the beacon itself, the supported references, the datasets
    and the total number of variants.
    """
    totalSize, resources = dataSetResources()
    info = {}
    info["beacon"] = beaconDesc
    info["references"] = beaconReferences
    info["datasets"] = resources
    info["size"] = totalSize
    return info

def printJson(data):
    """Print data to stdout as a JSON document, preceded by the CGI content-type header.

    The JSON is pretty-printed with an indent of 4 and sorted keys.
    """
    jsonOpts = {"indent": 4, "sort_keys": True, "separators": (',', ': ')}
    # the header goes out first, the "\n" plus print's own newline end the header block
    print("Content-Type: application/json\n")
    print(json.dumps(data, **jsonOpts))

def hgBotDelay():
    """Implement the bottleneck delay; the bottleneck server is taken from hg.conf.

    Does nothing if hg.conf has no "bottleneck.host" entry. Otherwise the
    delay for the address in REMOTE_ADDR is requested from the server. A
    delay above 10000 makes this function sleep for delay/1000 seconds, and
    a delay above 20000 additionally ends the request with "IP blocked".
    """
    global hgConf
    hgConf = parseHgConf("..")
    if "bottleneck.host" in hgConf:
        clientIp = os.environ["REMOTE_ADDR"]
        botHost = hgConf["bottleneck.host"]
        botPort = hgConf["bottleneck.port"]
        delay = queryBottleneck(botHost, botPort, clientIp)

        if delay > 10000:
            time.sleep(delay / 1000.0)
        if delay > 20000:
            # errAbort() prints the message and exits, it does not return
            errAbort("IP blocked")

def checkParams(chrom, pos, allele, reference, track):
    """Make sure the query parameters follow the spec and normalize them.

    Every violation is reported through errAbort(), which prints a JSON
    error and exits.

    Args:
        chrom: chromosome name, "1" to "22", "X", "Y", "M" or "test".
        pos: position as a string of ASCII digits.
        allele: alternate bases: [ACTG]+, I[ACTG]+ (insertion) or D[0-9]+
            (deletion); lower case is accepted.
        reference: genome reference; None or "" selects "GRCh37".
        track: optional dataset name, alphanumeric, at most 100 characters.

    Returns:
        The tuple (chrom, pos, allele, reference, track) with pos converted
        to int, allele converted to upper case and reference defaulted.
    """
    if reference is None or reference == "":
        reference = "GRCh37"
    if reference not in beaconReferences:
        errAbort("invalid 'reference' parameter, valid ones are %s" % ",".join(beaconReferences))
    if chrom is None or chrom == "":
        errAbort("missing chromosome parameter")
    if allele is None or allele == "":
        errAbort("missing alternateBases parameter")

    allele = allele.upper()
    # fullmatch() checks the whole string. match() only anchors at the start
    # and would let values like "A;x" through as a valid [ACTG]+ allele.
    if re.fullmatch("[ACTG]+|I[ACTG]+|D[0-9]+", allele) is None:
        errAbort("invalid alternateBases parameter, can only be a [ACTG]+ or I[ACTG]+ or D[0-9]+")

    if track is not None and track != "":
        if not track.isalnum():
            errAbort("'dataset' parameter must contain only alphanumeric characters")
        if len(track) > 100:
            errAbort("'dataset' parameter must not be longer than 100 chars")

    # explicit ASCII digits: str.isdigit() is also true for characters such
    # as superscript digits, on which int() raises a ValueError
    if pos is None or re.fullmatch("[0-9]+", pos) is None:
        errAbort("'position' parameter is not a number")
    pos = int(pos)

    isAutosome = re.fullmatch("[0-9]+", chrom) is not None and 1 <= int(chrom) <= 22
    if not (isAutosome or chrom in ["X", "Y", "M", "test"]):
        errAbort("invalid chromosome name %s" % chrom)

    return chrom, pos, allele, reference, track

def lookupAllele(chrom, pos, allele, reference, dataset):
    """Check if an allele is present in the sqlite DB.

    Args:
        chrom: chromosome name as stored in the database.
        pos: position as stored in the database.
        allele: alternate allele; ignored for the datasets in NoAltDataSets,
            where any variant at the position counts as a hit.
        reference: genome reference; not used by the lookup.
        dataset: name of the only dataset to search, or None / "" to search
            all datasets. An unknown dataset is reported through errAbort().

    Returns:
        The string "true" if a matching row was found, otherwise "false".
    """
    conn = dbOpen(mustExist=True)
    try:
        tableList = dbListTables(conn)
        if dataset is not None and dataset != "":
            if dataset not in tableList:
                errAbort("dataset %s is not present on this server" % dataset)
            tableList = [dataset]

        # Table names are formatted into the SQL because they cannot be bound
        # as parameters; they always come from the list of existing tables.
        for tableName in tableList:
            if tableName in NoAltDataSets:
                # some datasets don't have alt alleles, e.g. HGMD
                sql = "SELECT * from %s WHERE chrom=? AND pos=?" % tableName
                params = (chrom, pos)
            else:
                sql = "SELECT * from %s WHERE chrom=? AND pos=? AND allele=?" % tableName
                params = (chrom, pos, allele)

            cur = conn.cursor()
            cur.execute(sql, params)
            if cur.fetchone() is not None:
                return "true"

        return "false"
    finally:
        # release the database handle on every exit path
        conn.close()

def lookupAlleleJson(chrom, pos, altBases, refBases, reference, dataset):
    """Call lookupAllele and wrap the result into the response dictionaries.

    The parameters are validated with checkParams first. The returned dict
    has the keys "beacon", "query" (the normalized parameters) and
    "response" (the "exists" flag). A query for chromosome "test" at
    position 0 always answers "true".
    """
    chrom, pos, altBases, reference, dataset = checkParams(
        chrom, pos, altBases, reference, dataset)

    found = lookupAllele(chrom, pos, altBases, reference, dataset)
    # built-in test case, independent of the database contents
    if chrom == "test" and pos == 0:
        found = "true"

    return {
        "beacon": beaconDesc,
        "query": {
            "alternateBases": altBases,
            "referenceBases": refBases,
            "chromosome": chrom.replace("chr", ""),
            "position": pos,
            "reference": reference,
            "dataset": dataset,
        },
        "response": {"exists": found},
    }
    
def main():
    """Run as a CGI if REQUEST_METHOD is set in the environment, else as a command line tool."""
    if 'REQUEST_METHOD' not in os.environ:
        mainCommandLine()
        return

    # The special CGI error handler is enabled on every host except the
    # hgw*...ucsc.edu machines (the RR); hgwdev is the exception among those.
    isRrHost = hostName.startswith("hgw") and hostName.endswith("ucsc.edu")
    isDevHost = hostName.startswith("hgwdev.")
    if isDevHost or not isRrHost:
        cgitb.enable()
    mainCgi()

def parseArgs():
    """Parse the command line and return the tuple (args, options).

    Prints the usage message and exits with status 0 if no positional
    argument was given.
    """
    usage = """usage: %prog [options] filename [datasetName] - import VCF or BED files into the beacon database. 
    - parameter 'datasetName' is optional and defaults to 'defaultDataset'.
    - any existing dataset of the same name will be overwritten
    - the data is written to beaconData.sqlite. You can use 'sqlite3' to inspect the data file.
    - the input file can be gzipped

    e.g. to load HGMD:
    /usr/local/apache/cgi-bin/hgBeacon -f hgmd /tmp/temp.bed hgmd
    """
    parser = optparse.OptionParser(usage)
    parser.add_option(
        "-d", "--debug", dest="debug", action="store_true",
        help="show debug messages")
    parser.add_option(
        "-f", "--format", dest="format", action="store", default="vcf",
        help="format of input file, one of vcf, lovd, hgmd. default %default")

    options, args = parser.parse_args()
    if not args:
        parser.print_help()
        sys.exit(0)
    return args, options

def dbMakeTable(conn, tableName):
    """Create an empty table tableName with the fields chrom, pos and allele.

    A table of the same name that already exists is dropped first. Both
    steps are committed.
    """
    dropSql = "DROP TABLE IF EXISTS %s" % tableName
    conn.execute(dropSql)
    conn.commit()

    # the table has no PRIMARY KEY (chrom, pos, allele)
    createSql = "".join([
        'CREATE TABLE IF NOT EXISTS %s ',
        '(',
        '  chrom text,',   # chromosome
        '  pos int,',      # start position, 0-based
        '  allele text',   # alternate allele, can also be IATG = insertion of ATG or D15 = deletion of 15 bp
        ')',
    ]) % tableName
    conn.execute(createSql)
    conn.commit()

def dbOpen(mustExist=False):
    """Open the sqlite db and return a DB connection object.

    The database file is expected next to this script, except on hosts at
    ucsc.edu / sdsc.edu, where it is located in /gbdb/hg19/beacon/. With
    mustExist=True a missing database file is reported through errAbort().
    """
    fqdn = socket.getfqdn()
    atUcsc = hostName.endswith("ucsc.edu") or fqdn.endswith((".ucsc.edu", ".sdsc.edu"))
    if atUcsc:
        # at UCSC, the data is not in the CGI-BIN directory
        dbDir = "/gbdb/hg19/beacon/"
    else:
        # directory where the script is located
        dbDir = dirname(__file__)

    dbName = join(dbDir, SQLITEDB)
    if mustExist and not exists(dbName):
        errAbort("Database file %s does not exist. This beacon is not serving any data." % dbName)
    return sqlite3.Connection(dbName)

def dbListTables(conn):
    """Return the names of all tables in the sqlite db as a list."""
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    # every row is a 1-tuple that holds the table name
    return [name for (name,) in cursor.fetchall()]
    
def dbQuery(conn, query, params):
    """Run query on conn and return all result rows as a list.

    params are the values for the placeholders of the query; pass None for
    a query without placeholders.
    """
    cursor = conn.cursor()
    execArgs = (query,) if params is None else (query, params)
    cursor.execute(*execArgs)
    return cursor.fetchall()

def readAllelesVcf(ifh):
    """Read the alleles of a VCF file.

    Only the first five columns (CHROM, POS, ID, REF, ALT) are used. Every
    allele of a comma-separated ALT column is encoded on its own:

    - REF and ALT of one base each: substitution, allele is the ALT base
    - longer REF, ALT of one base: deletion, allele is "D<number of deleted bp>"
    - REF of one base, longer ALT: insertion, allele is "I<inserted bases>"
    - both longer than one base: cannot be encoded, the allele is skipped
    - ALT is "." or empty: no alternate allele, skipped

    For insertions and deletions the leading base that VCF adds in front of
    the change is removed and the position is moved by one accordingly.

    Args:
        ifh: iterable of text lines, e.g. a file opened in text mode.

    Returns:
        A list of (chrom, pos, allele) tuples without duplicates, in file
        order. pos is 0-based.
    """
    doneData = set() # copy of all data, to check for duplicates
    rows = []
    skipCount = 0
    emptyCount = 0

    # The garbage collector is switched off while the rows are collected
    # (presumably to speed up creating many small tuples) and put back into
    # its previous state afterwards, also when a line fails to parse.
    gcWasEnabled = gc.isenabled()
    gc.disable()
    try:
        for line in ifh:
            if line.startswith("#"):
                continue
            fields = line.rstrip("\r\n").split("\t", maxsplit=5)
            chrom, pos, varId, ref, altField = fields[:5]
            start = int(pos) - 1 # VCF is 1-based, beacon is 0-based
            refIsOne = len(ref) == 1

            for alt in altField.split(","):
                if alt == "." or alt == "":
                    emptyCount += 1
                    continue

                altIsOne = len(alt) == 1
                if refIsOne and altIsOne:
                    # single bp subst
                    dataRow = (chrom, start, alt)
                elif altIsOne:
                    # deletion: skip first nucleotide, VCF always adds one nucleotide
                    dataRow = (chrom, start + 1, "D" + str(len(ref) - 1))
                elif refIsOne:
                    # insertion
                    dataRow = (chrom, start + 1, "I" + alt[1:])
                else:
                    # neither REF nor ALT is a single base
                    skipCount += 1
                    continue

                if dataRow in doneData:
                    continue
                doneData.add(dataRow)
                rows.append(dataRow)

                # progress message, once per 500000 rows actually kept
                if len(rows) % 500000 == 0:
                    print("Read %d rows..." % len(rows))
    finally:
        if gcWasEnabled:
            gc.enable()

    print("skipped %d VCF lines with empty ALT alleles" % emptyCount)
    print("skipped %d VCF lines with both ALT and REF alleles len!=1, cannot encode as beacon queries" % skipCount)
    return rows

def readAllelesLovd(ifh):
    """Read the LOVD bed file and return it in the format (chrom, pos, altAllele).

    This function is only used internally at UCSC.

    The first four tab-separated columns are chrom, start, end and a variant
    description. Two kinds of descriptions are imported:

    - substitutions, which end in "X>Y": the allele is the last base
    - deletions, which end in "del": the allele is "D<end-start>"

    All other lines are counted as skipped, as are lines starting with
    "chrom" (the header).

    Args:
        ifh: iterable of text lines.

    Returns:
        A list of unique (chrom, start, allele) tuples in file order, with
        the "chr" prefix removed from chrom and start converted to int.
    """
    alleles = []
    skipCount = 0
    for line in ifh:
        if line.startswith("chrom"):
            continue
        chrom, start, end, desc = line.rstrip("\n").split("\t")[:4]

        # the length check keeps descriptions shorter than "X>Y" away from
        # the index access, they are treated like any other unknown variant
        if len(desc) >= 3 and desc[-2] == ">":
            alt = desc[-1]
        elif desc.endswith("del"):
            alt = "D" + str(int(end) - int(start))
        else:
            skipCount += 1
            continue

        alleles.append((chrom.replace("chr", ""), int(start), alt))

    print("read %d alleles, skipped %d non-SNV or del alleles" % (len(alleles), skipCount))
    # dict.fromkeys() removes duplicates like set() does, but keeps file order
    return list(dict.fromkeys(alleles))

def readAllelesHgmd(ifh):
    """Read the HGMD bed file and return it in the format (chrom, pos, altAllele).

    This function is only used internally at UCSC.

    Only lines whose variant type, read from column index 11, is
    "substitution" or "splicing variant" are kept. They must span exactly
    one base and get "*" as their allele. The result is a list of unique
    (chrom, start, "*") tuples, without the "chr" prefix, in no particular
    order.
    """
    # example row from the original comment:
    # chr1 2338004 2338005 PEX10:CM090797 0 2338004 2338005 PEX10 CM090797 substitution
    # NOTE(review): this example has fewer columns than the index 11 used
    # below needs - presumably it is abbreviated, confirm against a real file
    keptTypes = ("substitution", "splicing variant")
    alleles = []
    skipCount = 0
    for line in ifh:
        fields = line.rstrip("\n").split("\t")
        chrom, start, end = fields[:3]
        desc = fields[11]
        start = int(start)
        end = int(end)

        if desc not in keptTypes:
            skipCount += 1
            continue

        assert(end-start==1)
        alleles.append((chrom.replace("chr", ""), start, "*"))

    print("read %d alleles, skipped %d non-SNV alleles" % (len(alleles), skipCount))
    return list(set(alleles))

def iterChunks(seq, size):
    """Yield consecutive slices of seq that have at most size elements each."""
    yield from (seq[start:start + size] for start in range(0, len(seq), size))

def printTime(time1, time2, rowCount):
    """Print the time elapsed between time1 and time2 and the resulting rows per second.

    Args:
        time1: start time in seconds, e.g. from time.time().
        time2: end time in seconds.
        rowCount: number of rows processed between the two times.
    """
    timeDiff = time2 - time1
    # very fast steps can have an elapsed time of 0: report a rate of 0 then
    # instead of failing with a ZeroDivisionError
    rowsPerSec = rowCount / timeDiff if timeDiff > 0 else 0
    print("Time: %f secs for %d rows, %d rows/sec" % (timeDiff, rowCount, rowsPerSec))

def importFile(fileName, datasetName, format):
    """Write the data in fileName into the table datasetName of the sqlite db.

    The input file is read completely before the database is touched, so an
    unreadable file or an unknown format leaves an existing dataset of the
    same name as it is. After that the table is dropped and created again,
    filled and indexed.

    Args:
        fileName: input file; names ending in ".gz" are read as gzip.
        datasetName: name of the table to create; it becomes part of the SQL
            statements, so it has to be a valid SQL identifier.
        format: input format, one of "vcf", "lovd" or "hgmd".

    Raises:
        ValueError: if format is not one of the supported formats.
    """
    readers = {
        "vcf": readAllelesVcf,
        "lovd": readAllelesLovd,
        "hgmd": readAllelesHgmd,
    }
    if format not in readers:
        raise ValueError("unknown input format %r, must be one of vcf, lovd, hgmd" % (format,))

    print("Reading %s into database table %s" % (fileName, datasetName))
    startTime = time.time()

    # "rt": gzip.open() defaults to binary mode, but the readers expect str lines
    opener = gzip.open if fileName.endswith(".gz") else open
    with opener(fileName, "rt") as ifh:
        alleles = readers[format](ifh)

    loadTime = time.time()
    printTime(startTime, loadTime, len(alleles))

    conn = dbOpen()
    try:
        dbMakeTable(conn, datasetName)

        # try to make sqlite writes as fast as possible
        conn.execute("PRAGMA synchronous=OFF")
        conn.execute("PRAGMA count_changes=OFF") # http://blog.quibb.org/2010/08/fast-bulk-inserts-into-sqlite/
        conn.execute("PRAGMA cache_size=800000") # http://web.utk.edu/~jplyon/sqlite/SQLite_optimization_FAQ.html
        conn.execute("PRAGMA journal_mode=OFF") # http://www.sqlite.org/pragma.html#pragma_journal_mode
        conn.execute("PRAGMA temp_store=memory")
        conn.commit()

        # see http://stackoverflow.com/questions/1711631/improve-insert-per-second-performance-of-sqlite
        # for background on the chunked inserts
        print("Loading alleles into database")
        sql = "INSERT INTO %s (chrom, pos, allele) VALUES (?,?,?)" % datasetName
        for rows in iterChunks(alleles, 50000):
            conn.executemany(sql, rows)
            conn.commit()
        insertTime = time.time()
        printTime(loadTime, insertTime, len(alleles))

        print("Indexing database table")
        conn.execute("CREATE UNIQUE INDEX '%s_index' ON '%s' ('chrom', 'pos', 'allele')" % \
            (datasetName, datasetName))
        conn.commit()
        indexTime = time.time()
        printTime(insertTime, indexTime, len(alleles))
    finally:
        # release the database file also when loading or indexing fails
        conn.close()

def mainCommandLine():
    """Main function if called from the command line: import a file into the database.

    The first argument is the input file. The dataset name is the second
    argument if exactly two arguments were given, "defaultDataset" otherwise.
    """
    args, options = parseArgs()
    fileName = args[0]
    datasetName = args[1] if len(args) == 2 else "defaultDataset"
    importFile(fileName, datasetName, options.format)
    
def mainCgi():
    """Main function if called as a CGI: answer a single beacon request.

    The last part of the request path selects the endpoint: "info" returns
    the description of the beacon, everything else is an allele query. A
    query without chromosome, position and alternateBases shows the help
    page. The query result is printed as JSON, or as the bare "true"/"false"
    value if format=text was requested.
    """
    requestUri = os.environ["REQUEST_URI"]
    requestPath = urllib.parse.urlparse(requestUri)[2]

    # get CGI parameters
    form = cgi.FieldStorage()

    # react based on the symlink that was used to call this script:
    # the last part of the path is the REST endpoint
    endpoint = requestPath.rsplit("/", 1)[-1]
    if endpoint == "info":
        printJson(beaconInfo())
        sys.exit(0)

    hgBotDelay()

    getParam = form.getfirst
    chrom = getParam("chromosome")
    pos = getParam("position")
    refBases = getParam("referenceBases")
    altBases = getParam("alternateBases")
    reference = getParam("reference")
    dataset = getParam("dataset")
    outFormat = getParam("format")

    if chrom is None and pos is None and altBases is None:
        # printHelp() ends the program after printing the page
        printHelp()

    ret = lookupAlleleJson(chrom, pos, altBases, refBases, reference, dataset)
    if outFormat == "text":
        print("Content-Type: text/html\n")
        print(ret["response"]["exists"])
    else:
        printJson(ret)

if __name__=="__main__":
    # Entry point when this file is run as a script: main() decides between
    # CGI and command line mode.
    # deactivate this on the RR, but useful for debugging: prints a http header
    # on errors 
    # NOTE(review): the two lines above look like they describe a statement
    # that is no longer here (cgitb.enable() is now called in main()) - confirm
    main()
