#!/usr/bin/env python2.7

# create html/js files that describe a CDW dataset. Includes basic summary
# information, cell type assignment charts and cell clustering

import logging, sys, optparse, gzip, marshal, json, urllib2, re, tempfile, os, math, urllib
import string, unicodedata, itertools
from collections import defaultdict, Counter, namedtuple
from os.path import join, basename, dirname, isfile, isdir, getsize

# Replace the json encoder's float formatter so that floats are written with
# three decimal places. Background:
# http://stackoverflow.com/questions/1447287/format-floats-with-standard-json-module
from json import encoder
encoder.FLOAT_REPR = lambda o: format(o, '.3f')

# no need for garbage collection here: switch off the automatic collector
# for the whole run of the script
import gc
gc.disable()

# global var for command line options. Read as options.offline in
# writeHeader() and as options.meta in addMetaInfo().
# NOTE(review): presumably assigned from the result of parseCmdLine() by the
# caller; the assignment is not in this part of the file.
options = None

# run seurat over all genes or just genes with symbols ?
seuratAllGenes = False

# height of seurat plot in pixels
seuratHeight = 800

# How to convert the meta.tab fields to our format
# and how to show them on the page.
# Format of each entry: (metaTagName, myName, descriptionForHtmlPage)

# The fields to summarize at the top of the page:
# Only counts of the values are shown, except for the fields listed in
# metaFieldsOnlyOne, where only the most frequent value is shown.
# myName is the name used in the summary.tmp data cache and
# the sumData dictionary below. It's often shorter and easier to type.
# The myName 'geoAccs' is special: generateSummarySection() turns its values
# into links to NCBI GEO.
summaryTags = [
('title', 'title', 'Title'),
('GEO_Series_overall_design', 'design', 'Design'),
('GEO_Series_summary', 'summary', 'Description'),
('species', 'species', 'Species'),
('organ', 'organs', 'Organ'),
('submission_date', 'submission_date', 'Submission Date'),
('body_part', 'body_part', 'Body Part'),
('life_stage', 'lifestage', 'Lifestage'),
('sequencer', 'sequencer', 'Sequencer'),
('cell_type', 'cell_type', 'Cell Type'),
("GEO_Sample_c1_chip_id", 'chipId', "Fluidigm Chip ID"),
("GEO_Series_geo_accession", 'geoAccs', "GEO Accession"),
]

# List of meta fields to show for individual cells in the heatmap mouseover,
# in the cell summary page and on the PCA points.
# Format is the same as for summaryTags: (external name, internal name, description)
# NOTE(review): extractRelevantTags() currently returns all fields of a row and
# does not consult this list; confirm whether other parts of the file use it.
cellInfoTags = [
("biosample_source_age_value", "sampleAge", "Age"),
("body_part", "body_part", "Body Part"),
("GEO_Sample_experiment_sample_name", "sampleName", "Sample Name"),
("cell_type", "cell_type", "Cell Type"),
("sequencer", "sequencer", "Sequencer"),
("GEO_Sample_c1_chip_id", 'chipId', "Fluidigm Chip ID"),
("life_stage", "life_stage", "Life Stage")
]

# A list of internal names (the "myName" column of summaryTags) for which we
# show only the value that is most frequent in the meta.tab file, instead of
# a list of (value, count) pairs
metaFieldsOnlyOne = ['title', 'summary', "submission_date"]

# The attribute name by which the seurat plot is colored by default.
# Can be changed with the --colorBy command line option.
defaultSeuratAttribute = "tsneCluster"

def getCirmStaticFile(relPath):
    """ make relPath relative to the CIRM static data directory.

    The base directory is taken from the 'CIRM' environment variable if it
    is set. Otherwise it is /hive/groups/cirm/ on the host "hgwdev" and
    /pod/pstore/ on every other host.
    Returns the joined path as a string.
    """
    # 'platform' is not part of the module-level import list, so import it
    # here; without it the platform.node() call below raises a NameError
    import platform

    if platform.node()=="hgwdev":
        defaultBase = "/hive/groups/cirm/"
    else:
        defaultBase = "/pod/pstore/"
    return join(os.getenv("CIRM", defaultBase), relPath)

# === command line interface, options and help ===
def errAbort(msg):
    """ write 'Fatal Error: <msg>' to stderr and exit with status 1 """
    # diagnostics go to stderr so that they never end up mixed into output
    # that a caller captures from stdout
    sys.stderr.write("Fatal Error: "+msg+"\n")
    sys.exit(1)

def parseCmdLine():
    """ parse the command line with optparse and configure logging.

    Prints the help text and exits with status 1 if no positional argument
    was given. The logging level is DEBUG if -d/--debug was set, INFO
    otherwise.
    Returns a tuple (args, options): the list of positional arguments and
    the optparse options object.
    """
    parser = optparse.OptionParser(
    """usage: %prog [options] mysqlDatasetId matrixFname metaFname outDir - create html/js files
    that describes a CDW dataset. Includes basic summary information, cell type assignment
    charts and cell clustering

    matrixFname can be 'none' if no RNA-data is available.

    mysqlDatasetId it not necessary, can be any value.

    If Seurat is not installed on your machine, install it like this:
    Rscript -e 'install.packages("devtools", repos="http://cran.rstudio.com/"); \\
    library(devtools); source("http://bioconductor.org/biocLite.R"); \\
    install_url("https://github.com/satijalab/seurat/releases/download/v1.4.0/Seurat_1.4.0.5.tgz", binary = TRUE)'

    """)

    parser.add_option("-d", "--debug", dest="debug", action="store_true", help="show debug messages")
    parser.add_option("-m", "--markerFile", dest="markerFile", action="store", help="a file that defines marker genes for cell types, format: symbol,gene,cellType, default %default, the mouse version is /hive/groups/cirm/annotation/markers/darmanis2015/top500/mouseMarkers.tab", default=getCirmStaticFile("annotation/ensToHugo/barres.tab"))
    parser.add_option("", "--symbols", dest="useSyms", action="store_true", help="set this option if the rows in the input matrix file are symbols, not ENSGxxxx gene identifiers. This will read only symbols from the file specified with the -m option")
    parser.add_option("-o", "--offline", dest="offline", action="store_true", help="copy all required style and js files into the output directory")
    parser.add_option("-f", "--force", dest="force", action="store_true", help="do not use the results cache (e.g. summary.tmp) but recreate all input data")
    parser.add_option("-j", "--cellTreeJson", dest="cellTreeJson", action="store", help="a file with the D3-JSON-encoded cell clustering results, default %default", default=None)
    parser.add_option("-c", "--meta", dest="meta", action="store", help="name of the meta table column where the filename is stored. This field is used to link the meta table and RNA-seq output together. Default value:'%default'", default="meta")
    parser.add_option("-s", "--symTable", dest="symTable", action="store", help="use a table (transcriptId, geneId, symbol) to convert incoming gene identifiers to symbols, default %default, a version for mouse is /hive/groups/cirm/annotation/ensToHugo/genes.mm10", default=getCirmStaticFile("annotation/ensToHugo/genes.hg38"))
    parser.add_option("-a", "--showAllTags", dest="showAllTags", action="store_true", help="show all tags on mouse over, not just our favorite ones (cellInfoTags)")
    parser.add_option("", "--tagLabels", dest="tagLabels", action="store", help="short labels of tags, tab-sep file with two fields, tag name and label for it")
    parser.add_option("", "--seuratAllGenes", dest="seuratAllGenes", action="store_true", help="by default, Seurat is only run on genes that have a symbol. This parameter runs Seurat on all genes.")
    parser.add_option("", "--colors", dest="colors", action="store", help="json file with a nested dict. first key is meta attribute name, second key is a value of the attribute and the html name/hexcode of the color. Example: {'tissue':{'blood':'red'}}")
    parser.add_option("", "--colorBy", dest="colorBy", action="store", help="attribute to color the seurat plot on by default, default is %default", default=defaultSeuratAttribute)
    parser.add_option("-S", "--onlySeurat", dest="onlySeurat", action="store_true", help="create only the seurat plot")
    parser.add_option("", "--annoDir", dest="annovarDir", action="store", help="name of directory with annovar output files of VCFs")
    parser.add_option("", "--mani", dest="maniFname", action="store", help="if VCF mode: name of manifest file")
    parser.add_option("-p", "--pipeline", dest="pipeline", action="store", help="if VCF mode: name of pipeline for manifest file")
    parser.add_option("", "--seuratOut", dest="seuratOut", action="store", help="write Seurat data to json file")

    # note: this 'options' is a local name, the module-level global of the
    # same name is not modified here
    (options, args) = parser.parse_args()
    if not args:
        parser.print_help()
        # sys.exit instead of the bare exit(): the latter is only injected by
        # the 'site' module and is missing e.g. when python runs with -S
        sys.exit(1)

    if options.debug:
        logging.basicConfig(level=logging.DEBUG)
    else:
        logging.basicConfig(level=logging.INFO)
    return args, options
# ==== global ========

# NOTE(review): judging by the name, the path of the cellTypes.json file;
# it is not assigned or read in this part of the file - confirm where it is used
cellTypeJsonFname = None

# ==== functions =====

def makeTmpDirFor(outDir):
    " return the path outDir/tmp; the directory is created first if it is missing "
    scratchDir = join(outDir, "tmp")
    if isdir(scratchDir):
        return scratchDir
    os.makedirs(scratchDir)
    logging.info('Created %s' % scratchDir)
    return scratchDir

def lineFileNextRow(inFile):
    """
    parse a tab-sep file with headers in the first line and
    yield one collections.namedtuple ('tsvRec') per data line.

    inFile is either a file name (opened with gzip if it ends with ".gz")
    or an already opened file-like object. A file opened here is closed
    again when the generator finishes; a handle passed in is left open.

    Header handling: "#" is stripped from both ends of the header line, all
    characters outside [a-zA-Z0-9_] become "_", an empty header becomes
    "rowName" and a header starting with a digit is prefixed with "x".

    Data lines starting with "#" are skipped. Special characters are mapped
    to their closest ASCII equivalent, characters without one are dropped.
    A line is split into at most as many fields as there are headers, so
    surplus tabs stay in the last field.

    Raises Exception if a line has fewer fields than there are headers.
    """
    openedHere = isinstance(inFile, str)
    if not openedHere:
        fh = inFile
    elif inFile.endswith(".gz"):
        fh = gzip.open(inFile, 'rb')
    else:
        fh = open(inFile)

    try:
        line1 = fh.readline()
        # on Python 2 bytes is str, so this is only true for binary handles
        # on Python 3
        if isinstance(line1, bytes) and not isinstance(line1, str):
            line1 = line1.decode("latin1")
        line1 = line1.strip("\n").strip("#")
        headers = line1.split("\t")
        headers = [re.sub("[^a-zA-Z0-9_]","_", h) for h in headers]
        headers = [x if x!="" else "rowName" for x in headers]
        # namedtuple field names must not start with a digit
        headers = ["x"+h if h[0].isdigit() else h for h in headers]

        Record = namedtuple('tsvRec', headers)
        for line in fh:
            # byte strings (always on Python 2) are read as latin1, text
            # lines are used as they are
            if isinstance(line, bytes):
                line = line.decode("latin1")
            if line.startswith("#"):
                continue
            # map special chars in meta data to most similar ASCII equivalent
            line = unicodedata.normalize('NFKD', line).encode('ascii','ignore')
            # keep the native string type of the running interpreter
            if not isinstance(line, str):
                line = line.decode("ascii")
            line = line.rstrip("\n")
            fields = line.split("\t", len(headers)-1)
            try:
                rec = Record(*fields)
            except Exception as msg:
                logging.error("Exception occured while parsing line, %s" % msg)
                # file-like objects do not necessarily have a name
                logging.error("Filename %s" % getattr(fh, "name", "(unknown)"))
                logging.error("Line was: %s" % line)
                logging.error("Does number of fields match headers?")
                logging.error("Headers are: %s" % headers)
                raise Exception("header count: %d != field count: %d wrong field count in line %s" % (len(headers), len(fields), line))
            yield rec
    finally:
        # only close what was opened in this function
        if openedHere:
            fh.close()

def readGeneToSym(fname):
    """ given a file with geneId,symbol return a dict geneId -> symbol

    Two layouts are accepted: a tab-sep file with a header line that
    contains the columns geneId and symbol, or a plain key<tab>value file
    without header. In files with header, the part of geneId after the
    first "." is removed.
    """
    logging.info("Reading gene,symbol mapping from %s" % fname)

    # look only at the first line to find out which layout this is
    with open(fname) as ifh:
        line1 = ifh.readline()

    # Jim's files have no headers, they are just key-value
    if "geneId" not in line1:
        return parseDict(fname)

    # my own files have headers
    geneToSym = {}
    for row in lineFileNextRow(fname):
        geneToSym[row.geneId.split(".")[0]] = row.symbol
    return geneToSym

def dictToStr(d):
    """ flatten a nested D3 dict into a comma-sep string of names.
    A node with the name " " (one space) is replaced by the names of its
    children, any other node contributes its own name.
    """
    nodeName = d["name"]
    if nodeName == " ":
        return ",".join(dictToStr(child) for child in d["children"])
    return str(nodeName)

def jsonTreeToList(fname):
    """
    parse a D3-style hierarchical tree from the json file fname and return
    the names as a list, in the same order as in the file
    """
    # close the file right after parsing instead of leaving the handle open
    with open(fname) as ifh:
        data = json.load(ifh)
    return dictToStr(data).split(",")

def writeHeader(ofh, title):
    """ write html header part and body tag to the file handle ofh.

    Written are, in this order: the doctype and <head> with the links to
    bootstrap, the local /style and /js files, jquery, d3 and nvd3; unless
    the global options.offline is set, also the sorttable and scrollTo
    scripts; then the page <title> (title is inserted without escaping),
    the <body> tag, the <!--menuBar--> placeholder, an inline stylesheet and
    the opening tag of the 'mainContent' div. That div is not closed here.
    """
    # older versions of the external resources, kept for reference:
#<link rel="stylesheet" href="http://maxcdn.bootstrapcdn.com/bootstrap/3.3.5/css/bootstrap.min.css">
#<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/nvd3/1.8.1/nv.d3.min.css" >
#<script src="https://cdnjs.cloudflare.com/ajax/libs/nvd3/1.8.1/nv.d3.js"></script>
#<script src="https://cdnjs.cloudflare.com/ajax/libs/nvd3/1.8.1/nv.d3.min.js"></script>
#<script src="http://hgwdev.soe.ucsc.edu/~max/nvd3/build/nv.d3.min.js"></script>
    ofh.write("""
<!doctype html>
<html>
<head>
<meta charset="utf-8">
<link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css">
<link rel='STYLESHEET' href='/style/nice_menu.css' TYPE='text/css' />
<link rel='STYLESHEET' href='/style/HGStyle.css' TYPE='text/css' />
<link rel='STYLESHEET' href='/style/jquery-ui.css' TYPE='text/css' />

<script src="https://ajax.googleapis.com/ajax/libs/jquery/1.12.0/jquery.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.14/d3.min.js"></script>
<script src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/js/bootstrap.min.js"></script>
<script src="/js/jquery.plugins.js"></script>
<script src="/js/cdwSummaryJs.js"></script>
""")
    # switch between a private nvd3 build on hgwdev and the 1.8.1 release on
    # the cloudflare CDN. Hard-coded to False, so the CDN release is used.
    newNvd=False
    if newNvd:
        ofh.write("""<link rel="stylesheet" href="http://hgwdev.soe.ucsc.edu/~max/nvd3/build/nv.d3.min.css" >""")
        ofh.write("""<script src="http://hgwdev.soe.ucsc.edu/~max/nvd3/build/nv.d3.min.js"></script>""")
    else:
        ofh.write("""<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/nvd3/1.8.1/nv.d3.min.css" >""")
        ofh.write("""<script src="https://cdnjs.cloudflare.com/ajax/libs/nvd3/1.8.1/nv.d3.min.js"></script>""")

        #<script type="text/javascript" src="http://vaakash.github.io/jquery/collapser.js"></script>
        #<script src="http://d3plus.org/js/d3.min.js"></script>
        #<script src="http://d3plus.org/js/d3plus.min.js"></script>
    # these two scripts are only referenced from their external servers when
    # the page is not built in offline mode
    if not options.offline:
        ofh.write("""
        <script src="//www.kryogenix.org/code/browser/sorttable/sorttable.js"></script>
        <script src="//demos.flesler.com/jquery/scrollTo/js/jquery.scrollTo-min.js"></script>
        """)

    # the string <!--menuBar--> will get replaced with our menu bar html code by cdwGetFile

    ofh.write("""
    <title>%s</title>
</head>

<body>

<!--menuBar-->

<style>
#mainContent {
   margin-left: 20px;  
   color : #000;
   }

   table {
    background:#D9F8E4;
    border-collapse:collapse
   }

   thead tr {
     background-color:#1616D1;
     color:#FFFFFF;
     vertical-align: top;
     text-align: left
   }

   table, th, tr, td {
    border: 1px solid black;
    padding: 5px;
   }
</style>

<div id="mainContent">

    """ % title)
    # local copies of the scripts that were used previously:
    # <script type='text/javascript' SRC='../../js/jquery.js'></script>
    #<script type='text/javascript' SRC='../../js/jquery.plugins.js'></script>
    #<script type='text/javascript' SRC='../../js/jquery-ui.js'></script>
    #<script type='text/javascript' SRC='../../js/ajax.js'></script>

def writeCellTypeSection(ofh, jsonFname):
    """ write the 'Predicted Cell Types' section to the html file handle ofh.

    The content of the file jsonFname is pasted verbatim into the page as
    the javascript variable pieData and is drawn as an nvd3 pie chart.
    """
    # read the json text and close the file again right away
    with open(jsonFname, "r") as jsonFh:
        jsonData = jsonFh.read()

    ofh.write("""
<h2>Predicted Cell Types</h2>
<div id="pieChart" style="width:300px; height:300px"><svg style="width:300px; height:300px"></svg></div>
<script>
pieData = %s;


//Donut chart example
nv.addGraph(function() {
  var pieChartFunc = nv.models.pieChart()
      .x(function(d) { return d.label })
      .y(function(d) { return d.value })
      .showLegend(false)
      .showLabels(true)     //Display pie labels
      ;

    d3.select("#pieChart svg")
        .datum(pieData)
        .transition().duration(1)
        .call(pieChartFunc);
  nv.utils.windowResize(pieChartFunc.update);

  return pieChartFunc;
});

</script>""" % jsonData)

def parseMarker(fname, useSyms):
    """ parse a tab-sep marker gene file.

    The first line has to be the header and its first three columns have
    to be named symbol, gene and cellType; further columns are ignored.
    If useSyms is true, genes are identified by their symbol, otherwise by
    the gene ID with everything after the first "." removed.

    Returns a tuple (typeToGenes, geneToTypeList):
    - typeToGenes: defaultdict cellType -> set of gene identifiers
    - geneToTypeList: dict gene identifier -> list of cell types
    Returns (None, None) if fname is the string "none".

    Raises ValueError if the header line is not as described above.
    """
    if fname=="none":
        logging.info("No marker gene list, using all genes")
        return None, None
    logging.info("Reading marker gene list from %s" % fname)

    typeToGenes = defaultdict(set)
    geneToTypes = defaultdict(set)
    with open(fname) as ifh:
        headers = ifh.readline().rstrip("\n").split("\t")
        # an explicit check instead of assert: asserts are skipped with
        # python -O and give no hint about what is wrong with the file
        if headers[:3] != ["symbol", "gene", "cellType"]:
            raise ValueError("%s: header line has to start with the columns "
                "symbol, gene, cellType but the columns are %s" % (fname, headers))

        for line in ifh:
            # tolerate empty lines, e.g. at the end of the file
            if line.strip()=="":
                continue
            fs = line.rstrip("\n").split("\t")
            sym, geneId, cellType = fs[:3]
            if useSyms:
                geneId = sym
            else:
                geneId = geneId.split(".")[0]

            typeToGenes[cellType].add(geneId)
            geneToTypes[geneId].add(cellType)

    # convert to a dict string -> list
    geneToTypeList = {}
    for geneId, typeSet in geneToTypes.items():
        geneToTypeList[geneId] = list(typeSet)

    return typeToGenes, geneToTypeList

def parseDict(fname):
    """ parse text file in format key<tab>value and return as dict key->val.

    Files with a name ending in ".gz" are opened with gzip. Every line has
    to consist of exactly two tab-separated fields, otherwise the tuple
    unpacking raises a ValueError. For repeated keys the last value wins.
    """
    opener = gzip.open if fname.endswith(".gz") else open

    d = {}
    # the with-block makes sure that the file is closed again
    with opener(fname) as fh:
        for line in fh:
            key, val = line.rstrip("\n").split("\t")
            d[key] = val
    return d

def matrixToCellType(exprMat, markers, outDir):
    """
    try to guess the cell type for each cell based on marker genes.

    exprMat: dict cellName -> dict geneId -> expression value
    markers: dict cellType -> collection of marker gene IDs
    A cell gets the cell type with the highest number of marker genes that
    are present in its expression dict. For ties, the cell type whose name
    sorts last wins.

    Side effect: writes outDir/cellTypes.log, one line per cell with the
    fields cellName, assigned type, number of overlapping markers and a
    description "type:count:gene=value,..." of all types, joined by "|".

    Returns a dict cellName -> cellType
    """
    cellTypeFname = join(outDir, "cellTypes.log")

    cellToType = {}
    # the with-block flushes and closes the log before we report it as written
    with open(cellTypeFname, "w") as logOfh:
        for cellName, exprDict in exprMat.items():
            # for each cell type make dictionary of marker gene -> count
            cellOvls = []
            for cellType, markerSet in markers.items():
                overlapTrans = {}
                for markerId in markerSet:
                    if markerId in exprDict:
                        overlapTrans[markerId] = exprDict[markerId]
                cellOvls.append( (len(overlapTrans), cellType, overlapTrans) )
            # best overlap first. (count, cellType) is unique per entry, so
            # this is the order of a plain tuple sort, but the dicts in the
            # third position never have to be compared
            cellOvls.sort(key=lambda ovl: ovl[:2], reverse=True)

            topOvlLen, topCellType, _ = cellOvls[0]
            row = [cellName, topCellType, str(topOvlLen)]
            cellToType[cellName] = topCellType

            # create string describing how we got to this result
            descs = []
            for ovlLen, cellType, transDict in cellOvls:
                strList = ["%s=%f" % (trans, count) for trans, count in transDict.items()]
                descs.append("%s:%d:%s" % (cellType, ovlLen, ",".join(strList)))
            row.append("|".join(descs))

            logOfh.write("\t".join(row))
            logOfh.write("\n")
    logging.info("Assignment written to %s" % cellTypeFname)
    return cellToType

def convTabToMarshal(matrixFname):
    """ convert a genes-down-cells-right matrix to a .marshal file.

    Input is a tab-sep file (gzipped if the name ends with ".gz"): the
    first line holds the cell names from the second column on, every other
    line a gene ID followed by one value per cell. If a gene ID contains
    "_", only the part between the first and second "_" is kept.
    Only non-zero values are stored, as dict cellName -> geneId -> float.
    A cell without any non-zero value does not appear in the result.

    The dict is written with marshal to <matrixFname>.marshal.
    """
    marshFname = matrixFname+".marshal"
    logging.info("Converting %s to %s" % (matrixFname, marshFname))
    logging.info("Reading %s" % matrixFname)

    opener = gzip.open if matrixFname.endswith(".gz") else open

    cellTransCounts = defaultdict(dict)
    with opener(matrixFname) as ifh:
        cellNames = ifh.readline().rstrip("\n").split("\t")[1:]
        for line in ifh:
            fs = line.rstrip("\n").split("\t")
            geneId = fs[0]
            if "_" in geneId:
                geneId = geneId.split("_")[1]
            for cellName, count in zip(cellNames, fs[1:]):
                count = float(count)
                if count!=0.0:
                    cellTransCounts[cellName][geneId] = count
    # marshal cannot serialize a defaultdict, only plain dicts
    cellTransCounts = dict(cellTransCounts)

    logging.info("Writing %s" % marshFname)
    # marshal data is binary, so the file has to be opened in binary mode
    with open(marshFname, "wb") as ofh:
        marshal.dump(cellTransCounts, ofh)

def parseMatrix(matrixFname):
    """
    input: tab-sep file with one gene per row and one cellId per column
    returns: matrix as dict cellName -> transcriptId or geneId -> count
    side effect: creates <matrixFname>.marshal which is faster to load when called next time.

    An existing .marshal file is always used as it is, also when the matrix
    file has changed since it was created.
    """
    marshFname = matrixFname+".marshal"
    #if not isfile(marshFname) or options.force:
    if not isfile(marshFname):
        convTabToMarshal(matrixFname)

    logging.info("Reading %s" % marshFname)
    # marshal data is binary, so the file has to be opened in binary mode
    with open(marshFname, "rb") as ifh:
        exprMat = marshal.load(ifh)
    return exprMat

def writeCellTypeJson(cellToType, cellTypeFname):
    """ write the number of cells per cell type to the json file cellTypeFname.

    cellToType is a dict cellId -> cellType. The output is a json list of
    objects {"value": number of cells, "label": cell type}, most frequent
    cell type first.
    """
    counts = Counter(cellToType.values())
    countDicts = [{"value":count, "label":cellType} for cellType, count in counts.most_common()]
    # the with-block makes sure the data is on disk before we log success
    with open(cellTypeFname, "w") as ofh:
        json.dump(countDicts, ofh)
    logging.info("Wrote %s" % cellTypeFname)

def writeExprTransHistJson(exprMatrix, outFname):
    """
    create a json with the histogram of the number of expressed transcripts.

    exprMatrix is a dict cellName -> dict of expressed transcripts. The
    cells are binned by the size of their dict, in steps of 1000. The
    output file outFname holds a json list of {"x": lower bound of the bin,
    "y": number of cells in the bin}, sorted by x.
    """
    # bin data. The explicit floor division // rounds down to the bin start
    # no matter how the / operator is configured
    binCounts = defaultdict(int)
    for transCounts in exprMatrix.values():
        binCounts[1000*(len(transCounts) // 1000)] += 1

    # reformat for D3
    jsonArr = [{"x":x, "y":y} for x, y in sorted(binCounts.items())]

    # dump as json, the with-block makes sure the file is closed
    with open(outFname, "w") as ofh:
        json.dump(jsonArr, ofh)
    logging.info("Wrote %s" % outFname)

def writeTransCountSection(htmlFh, jsonFname):
    """
    write html page section with a histogram of the expressed transcripts
    per cell to the file handle htmlFh.

    The content of the file jsonFname is pasted verbatim into the page as
    the values of the javascript variable transCountData and is drawn as an
    nvd3 bar chart.
    """
    # read the json text and close the file again right away
    with open(jsonFname) as jsonFh:
        jsonData = jsonFh.read()

    htmlFh.write("""
<h2>Gene Expression</h2>
<div id="transHistChart"><svg style="width:500px; height:300px"></svg></div>
<script>
var transCountData = [{key: "number of samples", values : %s }];

nv.addGraph(function() {
  var transHistChart = nv.models.multiBarChart()
      .reduceXTicks(false)   //If 'false', every single x-axis tick label will be rendered.
      .rotateLabels(90)      //Angle to rotate x-axis labels.
      .showControls(false)   //Allow user to switch between 'Grouped' and 'Stacked' mode.
      .showLegend(false)
      .groupSpacing(0)    //Distance between each group of bars.
    ;
    
    transHistChart.xAxis.axisLabel("number of expressed genes");
    transHistChart.yAxis.axisLabel("number of samples");

    transHistChart.xAxis
        .tickFormat(d3.format(',f'));

    transHistChart.yAxis
        .tickFormat(d3.format(',f'));

    d3.select("#transHistChart svg")
        .datum(transCountData)
        .call(transHistChart);

    nv.utils.windowResize(transHistChart.update);

  return transHistChart;
});

</script>
""" % jsonData)

def stripTag(line, tag):
    """ return line without any html tags, i.e. without everything between
    a "<" and the next ">". The tag argument is not used. """
    return re.sub("<(.|\n)*?>", "", line)

def getAbstract(pmid):
    """ retrieve article info for pmid using the NCBI Entrez API.

    The XML reply is scanned line by line for the tags of interest, it is
    not parsed with an XML parser.

    Returns a dict with the keys year, journal, journalShort, authors,
    title, abstract and doi. A value is the empty string if the tag was not
    found. authors is a "; "-separated string of "LastName FirstName".

    Raises Exception if the server answers with an HTTP error.
    """

    logging.info("Sending NCBI entrez request for PMID %s" % pmid)
    url = '''https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi?db=pubmed&email=mhaeussl@ucsc.edu&retmode=xml&id=%s''' % pmid
    req = urllib2.Request(url)
    try:
        html = urllib2.urlopen(req)
    except urllib2.HTTPError:
        raise Exception("Could not get article info for PMID %s" % pmid)

    # journalShort needs a default too: without it the function fails with
    # an unbound variable for replies without <ISOAbbreviation>
    journal, journalShort, title, year = "", "", "", ""
    abstract = ""
    lastNames = []
    firstNames = []
    # true between a date tag and the next <Year> tag
    inDate = False
    doi = ""

    try:
        for line in html:
            line = line.strip()
            if line.find("<Title>")!=-1:
                journal = stripTag(line, "Title")
            if line.find("<ISOAbbreviation>")!=-1:
                journalShort = stripTag(line, "ISOAbbreviation")
            if line.find("<ArticleTitle>")!=-1:
                title = stripTag(line, "ArticleTitle")
            if line.find("<LastName>")!=-1:
                lastNames.append(stripTag(line, "LastName"))
            if line.find("<FirstName>")!=-1 or line.find("ForeName")!=-1:
                firstNames.append(stripTag(stripTag(line, "FirstName"), "ForeName"))
            if line.find("ArticleDate")!=-1 or line.find("PubDate")!=-1:
                inDate=True
            if line.find("<Year>")!=-1 and inDate:
                year=stripTag(line, "Year")
                inDate=False
            if line.find('ArticleId IdType="doi"')!=-1:
                doi=stripTag(line, "ArticleId")
            if line.find("<AbstractText>")!=-1:
                abstract=stripTag(line, "AbstractText")
    finally:
        html.close()

    # build the author string once at the end, not again for every line
    authors = "; ".join([last+" "+first for first, last in zip(firstNames, lastNames)])

    artData = {"year":year, "journal":journal, "authors":authors, \
        "title":title, "abstract":abstract, "doi":doi, "journalShort":journalShort}
    return artData
        
def addCdwDatasetInfo(datasetId, sumData):
    """ get info from the table cdw.cdwDataset and add to sumData.

    Runs the command line tool hgsql to get label, description and pmid of
    the dataset with the name datasetId and stores them under the keys
    "label", "description" and "pmid" of sumData.
    If hgsql cannot be started or does not return a complete row, a warning
    is logged and sumData is returned without these keys.

    Returns sumData.
    """
    # function-scope import, the module is only needed here
    import subprocess

    # Escape backslash and double quote so that the ID cannot terminate the
    # SQL string literal.
    # NOTE(review): assumes the MySQL default of backslash escapes being
    # active (no NO_BACKSLASH_ESCAPES sql mode) - confirm for the cdw server
    safeId = datasetId.replace("\\", "\\\\").replace('"', '\\"')
    query = 'select name, label, description, pmid from cdwDataset where name="%s"' % safeId

    # hgsql gets its arguments as a list, no shell is involved. So nothing
    # in datasetId is interpreted by a shell and no temp file is needed.
    try:
        proc = subprocess.Popen(["hgsql", "cdw", "-NBe", query],
            stdout=subprocess.PIPE, universal_newlines=True)
        output = proc.communicate()[0]
    except OSError as err:
        logging.warning("Could not run hgsql: %s" % err)
        logging.warning("No dataset info in mysql table cdwDataset found")
        return sumData

    # only the first line of the output is used
    row = output.split("\n", 1)[0].split("\t")
    if len(row) < 4:
        logging.warning("No dataset info in mysql table cdwDataset found")
        return sumData

    label, description, pmid = row[1:4]
    sumData["label"] = label
    sumData["description"] = description
    sumData["pmid"] = pmid
    return sumData

def extractRelevantTags(rowDict):
    """ return the meta data of one cell as a plain dict.

    All key/value pairs of rowDict are copied into a new dict. The
    restriction to the keys in the global cellInfoTags is switched off, so
    nothing is filtered or renamed at the moment.
    """
    return dict(rowDict)
    
def addMetaInfo(metaTabFname, sumData):
    """
    read the meta.tab file and add a summary of its tags to sumData.

    Added are:
    - sumData["cellInfo"]: dict cell ID -> dict with the fields of the cell.
      The cell ID is the value of the column named by the global options.meta.
    - for every entry of the global summaryTags that has values in the
      file: sumData[internal name] -> list of (value, count), most frequent
      first. For the names in metaFieldsOnlyOne only the most frequent
      value itself is stored.
    - sumData["title"], derived from sumData["datasetId"], if the file did
      not provide a title.

    Returns sumData.
    """
    # nested counter: tag name -> value -> number of rows with this value
    tagValCounts = defaultdict(Counter)
    infoByCell = {}
    for row in lineFileNextRow(metaTabFname):
        fieldDict = row._asdict()
        for tagName in fieldDict:
            tagValCounts[tagName][fieldDict[tagName]] += 1
        cellKey = fieldDict[options.meta]
        infoByCell[cellKey] = extractRelevantTags(fieldDict)

    # all fields of all cells are kept as they are
    sumData["cellInfo"] = infoByCell

    # the counts of the summary fields go directly into sumData
    for metaTag, myTag, desc in summaryTags:
        ranked = tagValCounts[metaTag].most_common()
        if len(ranked)==0:
            continue
        sumData[myTag] = ranked[0][0] if myTag in metaFieldsOnlyOne else ranked

    if "title" not in sumData:
        sumData["title"] = "Dataset %s" % sumData["datasetId"]

    return sumData
        
def addPubmedInfo(pmid, sumData, outDir):
    """ get article info from ncbi entrez for pmid
    author, title, doi etc and add it to sumData.

    The article info is cached in outDir/pubmed.tmp: if this file exists it
    is used and no request is sent, otherwise the info is retrieved with
    getAbstract() and written to the file. Every key of the article info is
    added to sumData with the prefix "art_". For an empty pmid neither the
    cache nor the network is touched.
    sumData["art_pmid"] is always set to pmid.

    Returns sumData.
    """
    if pmid!="":
        artFname = join(outDir, "pubmed.tmp")
        if isfile(artFname):
            logging.info("Reading article data from %s" % artFname)
            with open(artFname) as ifh:
                artData = json.load(ifh)
        else:
            artData = getAbstract(pmid)
            # the with-block makes sure the cache file is complete on disk
            with open(artFname, "w") as ofh:
                json.dump(artData, ofh)
            logging.info("Wrote article data to %s" % artFname)
        for artKey, artVal in artData.items():
            sumData["art_"+artKey] = artVal
    sumData["art_pmid"] = pmid
    return sumData

def addMarkerGenes(geneToTypes, exprMatrix, sumData):
    """
    write the expr counts for all marker genes into the sumData
    dict as a matrix, not as a dict
    format:
    sumData["markers"] = geneToTypes itself (marker gene ID -> cell types)
    sumData["exprRows"] -> list of [exprValCell0, exprValCell1,...]
    exprRows contains one list per marker gene, in the iteration order of
    geneToTypes. The values within a row are in the order of
    sumData["cellIds"], which has to be set before calling this function.
    A gene that is not expressed in a cell gets the value 0.0. Marker genes
    that are not in the matrix at all are kept, as a row of 0.0 values.

    Returns sumData.
    """
    sumData["markers"] = geneToTypes

    cellIds = sumData["cellIds"]
    exprRows = []
    for markerId in geneToTypes:
        exprRows.append([exprMatrix[cellId].get(markerId, 0.0) for cellId in cellIds])
    sumData["exprRows"] = exprRows

    return sumData

def addAnalysisInfo(fname, sumData):
    """ read an additional meta.tab file with information supplied by a
    pipeline tool. At the moment, this is kallistoInReadCount,
    kallistoAlnReadCount and kallistoEstFragLen.

    Adds sumData["toolMeta"], a list with one tuple of values per row, and,
    if the file has at least one row, sumData["toolMetaFields"], the names
    of the columns. Nothing is added if fname is None or not a file.

    Returns sumData.
    """
    if fname is None:
        return sumData
    if not isfile(fname):
        logging.info("Could not find %s" % fname)
        return sumData

    logging.info("Reading additional analysis info file %s" % fname)
    rowTuples = []
    for rec in lineFileNextRow(fname):
        sumData["toolMetaFields"] = rec._fields
        rowTuples.append(tuple(rec))
    sumData["toolMeta"] = rowTuples

    return sumData

def getCommonCellIds(matrixCellIds, metaCellIds):
    """ return the sorted list of cellIds that are common between the matrix
    and the meta info and inform the user about IDs that are in only one of
    them.

    Returns an empty list if matrixCellIds is empty.
    Raises Exception if the matrix has cell IDs but none of them is in the
    meta data.
    """
    if len(matrixCellIds)==0:
        return []

    matrixCellIds = set(matrixCellIds)
    metaCellIds = set(metaCellIds)

    commonIds = matrixCellIds.intersection(metaCellIds)
    if len(commonIds)==0:
        raise Exception("No overlap between cellIds from meta data and cellIds from the matrix")

    excessMatrix = matrixCellIds - metaCellIds
    excessMeta = metaCellIds - matrixCellIds

    # the IDs are sorted so that the messages are the same for every run
    if len(excessMatrix)!=0:
        logging.warning("The matrix contains some cell identifiers that are not in the meta data: %s" % ",".join(sorted(excessMatrix)))

    if len(excessMeta)!=0:
        logging.warning("The meta info contains some cell identifiers that are not in the matrix: %s" % ",".join(sorted(excessMeta)))

    # say only once what we do about it, also if both sides have excess IDs
    if len(excessMatrix)!=0 or len(excessMeta)!=0:
        logging.warning("Using only %d identifiers that are in both the matrix and the meta data." % len(commonIds))

    return sorted(commonIds)

def writeCacheFile(datasetId, exprMatrix, metaTabFname, addMetaTabFname, geneToTypes, outDir, cacheFname):
    """
    write a file with basic info about this project.
    Info is collected from matrix, tags and mysql.
    We do this so we don't have to parse all the big matrices each time
    when we want to change only a little html tag during development

    exprMatrix can be None if there is no expression data, metaTabFname can
    be the string "none" if there is no meta data file. In the latter case
    addMetaTabFname is not read either.

    The collected data is written as json to cacheFname and is also
    returned, as a dict.
    """
    sumData = {}
    sumData["cellCount"] = len(exprMatrix) if exprMatrix is not None else 0
    sumData["datasetId"] = datasetId

    # get cdw dataset mysql table info
    sumData = addCdwDatasetInfo(datasetId, sumData)
    if metaTabFname!="none":
        sumData = addMetaInfo(metaTabFname, sumData)
        sumData = addAnalysisInfo(addMetaTabFname, sumData)

    if exprMatrix is not None:
        if "cellInfo" in sumData:
            sumData["cellIds"] = getCommonCellIds(list(exprMatrix), list(sumData["cellInfo"]))
        else:
            # No meta data was read, so there is nothing to intersect with.
            # NOTE(review): using all cells of the matrix here replaces a
            # KeyError on sumData["cellInfo"] - confirm that this is the
            # wanted behavior for runs without a meta file
            sumData["cellIds"] = sorted(exprMatrix)

        sumData = addMarkerGenes(geneToTypes, exprMatrix, sumData)

    if "pmid" in sumData:
        sumData = addPubmedInfo(sumData["pmid"], sumData, outDir)

    # the with-block makes sure the cache is complete on disk before we log
    with open(cacheFname, "w") as ofh:
        json.dump(sumData, ofh)
    logging.info("Wrote %s" % cacheFname)
    return sumData

def generateSummarySection(summData, htmlFh):
    """ write the summary section at the top of the page to htmlFh.

    summData is the dict of the summary cache. Written are the title, the
    dataset ID with the submission date (if there is one), the sample count
    (if not 0), one line per entry of the global summaryTags that is present
    in summData and, if article info is present, the article with journal
    and abstract. Values are inserted into the html without escaping.

    Side effect: sets summData["firstAuthor"] if article info is present.
    """
    htmlFh.write("""
    <h2>%(title)s</h2>
    <div style="width:800px">""" % summData)

    if "submission_date" in summData:
        htmlFh.write("""<b>UCSC Dataset ID:</b> <A HREF=\"cdwWebBrowse?cdwCommand=browseFiles&cdwBrowseFiles_f_data_set_id=%(datasetId)s\">%(datasetId)s</A>&nbsp; &nbsp; <b>submitted:</b> %(submission_date)s &nbsp;<br>
    """ % summData)
    sampleCount = summData["cellCount"]
    if sampleCount!=0:
        htmlFh.write("<b>Sample count:</b> %d<br>" % sampleCount)

    # these two are already part of the lines written above
    skipTags = ['submission_date', 'title']

    for metaTag, myTag, desc in summaryTags:
        if myTag in skipTags:
            continue
        if myTag not in summData:
            logging.info("Could not find tag '%s' in summary meta data" % myTag)
            continue
        tagVals = summData[myTag]
        if len(tagVals)==0:
            continue
        if myTag in metaFieldsOnlyOne:
            # a single string, not a list of (value, count)
            metaHtml = tagVals
        else:
            # no need for a count if all cells have the same value
            showCounts = (len(tagVals)!=1)
            strList = []
            for val, count in tagVals:
                if myTag=="geoAccs":
                    val = '<a href="https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=%s">%s</a>' % (val, val)
                if showCounts:
                    strList.append("%s (%d)" % (val, count))
                else:
                    strList.append(val)
            metaHtml = ", ".join(strList)
        htmlFh.write("<b>%s:</b> %s<br>\n" % (desc, metaHtml))

    # The check has to be on art_authors: firstAuthor is only derived from it
    # in the next line, so a check for firstAuthor itself would never be true
    # and the article would never be shown
    if "art_authors" in summData:
        summData["firstAuthor"] = summData["art_authors"].split(" ")[0]
        htmlFh.write("""
        <br><b>Article:</b> %(firstAuthor)s et al: <a href="https://www.ncbi.nlm.nih.gov/pubmed/%(art_pmid)s">%(art_title)s</a><br>
        <b>Journal:</b> %(art_journalShort)s %(art_year)s<br>
        <b>Article Abstract:</b> <span class="collapse in">%(art_abstract)s</span><br>
        """ % summData)
    htmlFh.write("</div>\n")

def writeHeatmapHeader(ofh):
    """ write the html header for the heatmap page to the open file ofh

    Writes the generic page header via writeHeader(), then the style sheet
    for the rotated column headers / highlighted rows and columns and the
    short usage text shown above the heatmap table.
    """
    writeHeader(ofh, "Heatmap")

    # two declarations in th.rotate and table.heatTable were invalid CSS and
    # ignored by browsers: 'text-alignment' (the property is text-align) and
    # the five-digit color '#CCCCC'
    ofh.write("""
<style>
   th.rotate {
     white-space: nowrap;
     text-align: left;
     vertical-align: bottom;
   }
   
   th.rotate > div {
     float:left;
     white-space: nowrap;
     position: relative;
     border-style: none;
     transform: rotate(270deg);
     -webkit-transform: translateX(20px) rotate(-90deg);
     -moz-transform: translateX(20px) rotate(270deg);
     -ms-transform: translateX(20px) rotate(270deg);
     -o-transform: translateX(20px) rotate(270deg);

     transform-origin: bottom left;
     -webkit-transform-origin: bottom left;
     -moz-transform-origin: bottom left;
     -ms-transform-origin: bottom left;
     -o-transform-origin: bottom left;
   }

   th.rotate > div > span {
     /* border-bottom: 1px solid #ccc; */
     padding-bottom: 5px;
     white-space: nowrap;
   }

   table.heatTable {
    background:white;
    table-layout:fixed;
    width:100%;
    padding: 0px;
    border:1px;
    border-color: #CCCCCC;
    margin: 0px;
    border-collapse:collapse
   }

   .hlRow {
    font-weight: bold;
    border:dashed black;
   }

   .hlCol {
    font-weight: bold;
    border-left:dashed black;
    border-right:dashed black;
   }


</style>

    <h2>Expression Heatmap</h2>
    Mouse-over a gene ID to show cell types for this marker.<br>
    Mouse-over a cell ID to show cell source and type.<br>
    Mouse-over any cell to show gene, cell-ID and expression TPM value. <p>
    """)

def reorderCells(jsonFname, cellIds, exprMatrix):
    """
    put the cellIds and expr. rows into the same order as in the D3 JSON file

    jsonFname:  file handed to jsonTreeToList(), which returns the cell IDs in
                the order they should be shown
    cellIds:    list of cell IDs, in the column order of exprMatrix
    exprMatrix: list of (geneId, exprList) with one value per entry of cellIds

    Returns (newCellIds, newMatrix) in the new order. Cells that are not part
    of the JSON tree are dropped. Raises ValueError if the tree contains a
    cell ID that is not in cellIds.
    """
    treeCellIds = jsonTreeToList(jsonFname)

    # map cellId -> first position in cellIds, so that each lookup is a dict
    # access instead of a scan over the whole list
    firstPos = {}
    for pos, cellId in enumerate(cellIds):
        firstPos.setdefault(cellId, pos)

    newOrder = []
    for cellId in treeCellIds:
        if cellId not in firstPos:
            # same exception type as list.index(), which was used here before
            raise ValueError("cell ID %r from %s is not in the list of cell IDs" % (cellId, jsonFname))
        newOrder.append(firstPos[cellId])

    newCellIds = [cellIds[x] for x in newOrder]

    newMatrix = []
    for geneId, exprList in exprMatrix:
        newExprList = [exprList[x] for x in newOrder]
        newMatrix.append( (geneId, newExprList) )

    return newCellIds, newMatrix

def transMatrix(cellIds, exprMatrix):
    """
    encodes the matrix in the format cell - gene
    returns a list of rows, rows are cells, columns are genes

    cellIds:    list of cell IDs, one per column of the input
    exprMatrix: sequence of (geneId, exprList), exprList has one value per cell

    The first output row is the header: "cellId" followed by the gene IDs
    that were kept. Genes whose values are all 0 are left out.
    """
    # the input is walked twice below, so make sure it is not a one-shot iterator
    exprMatrix = list(exprMatrix)

    # make a list of genes that have no expression anywhere
    # important for PCA, as we need to remove genes that lead to
    # 0 division when scaling the columns
    skipGenes = []
    validGenes = []
    for geneId, exprList in exprMatrix:
        if list(set(exprList))==[0.0]:
            skipGenes.append(geneId)
        else:
            validGenes.append(geneId)

    logging.info("These genes have 0-only data: %s" % ",".join(skipGenes))

    # filter the gene rows once, with a set, instead of testing list membership
    # again for every gene of every cell
    skipSet = set(skipGenes)
    keptLists = [exprList for geneId, exprList in exprMatrix if geneId not in skipSet]

    cellRows = [["cellId"] + validGenes]
    for cellIdx, cellId in enumerate(cellIds):
        cellRows.append([cellId] + [exprList[cellIdx] for exprList in keptLists])

    return cellRows

def _heatColorIdx(exprVal, colBins):
    """ return the index of the color bin for exprVal

    Bin i covers colBins[i] <= exprVal < colBins[i+1]. Values below the first
    threshold go into bin 0, values at or above the last threshold go into
    the last bin.
    """
    for colIdx in range(len(colBins)-1):
        if exprVal < colBins[colIdx+1]:
            return colIdx
    return len(colBins)-1

def writeHeatmapPage(outDir, sumData):
    """
    write the file heatmap.html into outDir, based on the exprRows with
    geneIds in rows and cellIds in columns

    sumData keys used: "markers" (marker -> cell types), "cellIds",
    "exprRows" and "cellInfo" (cellId -> dict of meta values).
    If options.cellTreeJson is set, the columns are reordered with
    reorderCells().

    Returns (cellIds, exprMatrix) in the order used on the page, exprMatrix
    is a list of (markerId, exprRow).
    """

    markerToCellType = sumData["markers"]
    cellIds = sumData["cellIds"]
    exprRows = sumData["exprRows"]
    cellInfo = sumData["cellInfo"]

    exprMatrix = list(zip(markerToCellType, exprRows))
    if options.cellTreeJson!="" and options.cellTreeJson is not None:
        cellIds, exprMatrix = reorderCells(options.cellTreeJson, cellIds, exprMatrix)

    # created with:
    # $ pip install colour
    # $ python
    # from colour import Color
    # print [h.hex for h in list(Color("red").range_to(Color("blue")))]
    colors = list(reversed(['#f00', '#ff7100', '#ffe300', '#af0', '#39ff00', '#00ff39', '#0fa', '#00e3ff', '#0071ff', '#00f']))

    # the thresholds to assign colors to the cells
    colBins = [0, 0.5, 1, 10, 50, 100, 1000, 10000, 100000, 1000000]

    htmlFname = join(outDir, "heatmap.html")
    # 'with' makes sure that the page is flushed and closed, also on errors
    with open(htmlFname, "w") as ofh:
        writeHeatmapHeader(ofh)

        ofh.write("""
    <table class="heatTable">
    <thead>
    <tr style="background-color:#F0F0F0; vertical-align: bottom; text-align: left">
    <th style="width:90px; height:100px;">Gene</th>
    """)

        for cellId in cellIds:
            if cellId in cellInfo:
                mouseOver = ", ".join(cellInfo[cellId].values())
            else:
                # still write the header cell: the body below has one column
                # per cell, skipping the header would shift all columns
                logging.warning("no meta info for cell %s" % cellId)
                mouseOver = ""
            ofh.write('<th id="%s" style="width:20px" class="rotate" title="%s"><div><span>%s</span></div></th>\n' % (cellId, mouseOver, cellId))

        ofh.write("""</tr></thead>

    <tbody>
    """)

        for markerId, exprRow in exprMatrix:
            ofh.write("<tr id='%s'>\n" % markerId)
            ofh.write("<td title='%s'>%s</td>\n" % \
                    (",".join(markerToCellType[markerId]), markerId))
            for cellId, exprVal in zip(cellIds, exprRow):
                color = colors[_heatColorIdx(exprVal, colBins)]
                ofh.write("<td id='%s' title='%s %s TPM=%0.2f' style='background-color:%s'></td>" % \
                    (cellId, markerId, cellId, exprVal, color))
            ofh.write("</tr>\n")

        ofh.write("""</tbody></table>""")
        ofh.write("""
<script>
$(document).ready(doHl);
$(window).bind('hashchange', doHl);

function doHl() {
    /* parse the row and column to highlight from current URL and set CSS class */
    var url = window.location.href;
    var urlParts = url.split('#');
    if ((urlParts.length)>1) {
        var anchorParts  = urlParts[1].split("-");
        var rowId = anchorParts[0];
        var colId = anchorParts[1];
        if (rowId!='')
            $('#'+rowId).addClass('hlRow');
        if (colId!='') {
            $('td[id='+colId+']').addClass('hlCol');
            $('th[id='+colId+']').addClass('hlCol');
            $.scrollTo("#"+colId);
        }
    }
}
</script>
    """)
        ofh.write("""</body></html>""")

    return cellIds, exprMatrix

def prepCellRows(toolFields, toolRows, cellInfo):
    """ merge the cell info and tool meta info dicts to a normal table, calculate a few QC-measures

    toolFields: list of field names of the tool meta table
    toolRows:   list of rows of the tool meta table, the first value is the cell ID
    cellInfo:   dict cellId -> dict with the meta data of the cell

    Returns a list of dicts, one per row of toolRows. Every dict has the keys
    "cellId", "alnShare", "kallistoEstFragLen", "kallistoAlnReads" and
    "kallistoProcReads" (the read counts are in thousands) plus all tool and
    meta fields. The kallisto values are "NA" if the row has no kallisto data.
    Raises KeyError if a cell ID is not in cellInfo.
    """

    outRows = []
    for row in toolRows:
        cellId = row[0]
        rowDict = {"cellId" : cellId}

        # add the toolMeta.tab info
        for fieldName, val in zip(toolFields, row):
            rowDict[fieldName] = val

        # merge in the cell info from the meta.tab file
        rowDict.update(cellInfo[cellId])

        if "kallistoAlnReads" in rowDict:
            alnReads = float(rowDict["kallistoAlnReads"])
            procReads = float(rowDict["kallistoProcReads"])
            if procReads==0:
                # no reads were processed: the share is undefined, avoid the 0 division
                rowDict["alnShare"] = "NA"
            else:
                rowDict["alnShare"] = "%d %%" % (100*alnReads/procReads)
            rowDict["kallistoEstFragLen"] = str(int(float(rowDict["kallistoEstFragLen"])))
            # show the read counts in thousands
            for tag in ["kallistoAlnReads", "kallistoProcReads"]:
                rowDict[tag] = "%d" % (int(round(float(rowDict[tag])/1000)))
        else:
            for tag in ["alnShare", "kallistoEstFragLen", "kallistoAlnReads", "kallistoProcReads"]:
                rowDict[tag] = "NA"

        outRows.append(rowDict)
    return outRows

def writeCellPage(outDir, sumData):
    """
    write a separate page cellInfo.html into outDir, with the sequencing info
    and the meta data of each cell

    sumData keys used: "toolMeta" and "toolMetaFields" (both optional) and
    "cellInfo". The rows are prepared by prepCellRows() and sorted by cell ID.
    The table has one column per field, the cell ID first, then all other
    fields in alphabetical order. The column header is the field name.
    """
    toolInfo = sumData.get("toolMeta", [])
    toolFields = sumData.get("toolMetaFields", [])
    cellInfo = sumData["cellInfo"]

    outRows = prepCellRows(toolFields, toolInfo, cellInfo)
    outRows.sort(key=lambda rowDict: rowDict["cellId"])

    # Use one fixed column order for the header and all rows. Iterating over
    # each row's dict does not guarantee the same order for every row and
    # leaves the table without any header to click on for sorting.
    otherFields = set()
    for rowDict in outRows:
        otherFields.update(rowDict)
    otherFields.discard("cellId")
    columns = ["cellId"] + sorted(otherFields)

    doSmall = ["GEO_Sample_age", "body_part"]

    fname = join(outDir, "cellInfo.html")
    with open(fname, "w") as ofh:
        writeHeader(ofh, "Single Cell Details")
        ofh.write("<p><h2>Single Cell Details</h2>\n")

        ofh.write("""
<style>
   table {
    background:#D9F8E4;
    border-collapse:collapse
   }

   thead tr {
     background-color:#1616D1;
     color:#FFFFFF;
     vertical-align: top;
     text-align: left
   }

   table, th, tr, td {
    border: 1px solid black;
    padding: 5px;
   }
</style>
    """)

        ofh.write("""<table class='sortable'>
<thead>
<tr>
    """)

        for fieldName in columns:
            ofh.write('<th>%s</th>\n' % fieldName)

        ofh.write("</tr></thead>\n")
        ofh.write("<tbody>\n")

        for rowDict in outRows:
            ofh.write("<tr>\n  ")
            for fieldName in columns:
                # a field can be missing for some cells: keep the column, leave it empty
                val = rowDict.get(fieldName, "")
                if fieldName=="sequencer":
                    val = val.replace("Illumina ", "")
                if fieldName in doSmall:
                    ofh.write("<td style='line-height:0.7em'><small>%s</small></td>" % val)
                elif fieldName=="cellId":
                    # link to the column of this cell on the heatmap page
                    ofh.write("<td><a href='heatmap.html#-%s'>%s</a></td>" % (val, val))
                else:
                    ofh.write("<td>%s</td>" % val)
            ofh.write("</tr>\n")

        ofh.write("</tbody></table></body></html>\n")

def writeMatrix(rows, fname, headers=None):
    """ write a list of rows to tab sep file

    rows:    iterable of rows, every value is converted with str()
    fname:   name of the output file, it is overwritten
    headers: optional list of strings, written as the first line
    """
    # 'with' closes the file also if one of the writes fails
    with open(fname, "w") as ofh:
        if headers:
            ofh.write("\t".join(headers))
            ofh.write("\n")

        for row in rows:
            ofh.write("\t".join([str(x) for x in row]))
            ofh.write("\n")
    logging.info("Wrote %s" % fname)

def runRscript(scriptStr, scriptFname):
    """ run an R script given as a string through Rscript

    scriptStr is written to the file scriptFname, which is then run with
    'time Rscript' through the shell. Raises AssertionError if the command
    does not return 0.
    """
    with open(scriptFname, "w") as scriptFh:
        scriptFh.write(scriptStr)
    logging.info("Wrote %s, running through Rscript" % scriptFname)

    cmd = "time Rscript %s" % scriptFname
    # Do not run the command inside an assert statement: with 'python -O' the
    # whole statement is removed and the script would never be executed.
    # AssertionError is kept as the error type for existing callers.
    retCode = os.system(cmd)
    if retCode!=0:
        raise AssertionError("Command returned non-zero status %d: %s" % (retCode, cmd))

def calcPca(matrixFname, outFname, tmpDir):
    """ run the matrix in matrixFname through the R function prcomp and write the
    principal components to outFname. The R script is saved as runPca.R in tmpDir.
    """
    # placeholders: input matrix file, output file
    rTemplate = """data = read.table("%s", header=TRUE, row.names=1)
data = log(data+1)
pca = prcomp(data, scale.=TRUE, center=TRUE)
write.table(pca$x, "%s", sep="\t", quote=FALSE, col.names=NA)
"""
    rCode = rTemplate % (matrixFname, outFname)
    runRscript(rCode, join(tmpDir, "runPca.R"))
    
def parsePca(pcaFname):
    """ read the scores from an R PCA result file. Returns a dict
    cellId -> tuple with the first five principal components as floats
    """
    logging.info("Parsing %s to generate Principal Components plot" % pcaFname)
    pcFields = ("PC1", "PC2", "PC3", "PC4", "PC5")
    scores = {}
    for row in lineFileNextRow(pcaFname):
        scores[row.rowName] = tuple(float(getattr(row, field)) for field in pcFields)
    return scores

#def writePcaPlot(htmlFh):
    #""" write D3 code to plot the PCA results """
    #htmlFh.write("""
#""")

def writePcaSection(htmlFh, matrixFname, exprMatrix, markerGenes, cellInfo, outDir):
    """ add the PCA plot to the html page

    The PCA result is cached in tmp/pcaResult.tab under outDir. R is only run
    if this file does not exist or options.force is set. In that case
    exprMatrix is used (parsed from matrixFname if it is empty), restricted
    to markerGenes unless markerGenes is None.
    Only cells that are in cellInfo are added to the plot data.
    """
    htmlFh.write("<h2>Principal Components Analysis</h2>\n")

    tmpDir = join(outDir, "tmp")
    pcaFname = join(tmpDir, "pcaResult.tab")
    haveCache = isfile(pcaFname) and not options.force
    if not haveCache:
        if not exprMatrix:
            exprMatrix = parseMatrix(matrixFname)
        if markerGenes is not None:
            markerGenes = set(markerGenes)
        pcaMatrixFname = join(tmpDir, "pcaInMatrix.tab")
        writeMatrix(dictToCellMatrix(exprMatrix, markerGenes), pcaMatrixFname)
        calcPca(pcaMatrixFname, pcaFname, tmpDir)
    else:
        logging.info("Not re-running R, cache file %s still exists" % pcaFname)

    # prepare PCs as a list of javascript maps
    pcaData = []
    for cellId, pcs in parsePca(pcaFname).iteritems():
        if cellId in cellInfo:
            pc1, pc2, pc3, pc4, pc5 = pcs
            pcaData.append({ "pc1":pc1, "pc2":pc2, "pc3":pc3, "pc4":pc4, "pc5":pc5, "cellId":cellId })
    pcaDataJson = json.dumps(pcaData)

    htmlFh.write("""

<div id="pcaChartDiv"><svg style="width:800px;height:400px"></svg>
    <div style="margin-left: 100px">
    <form action="no">
    X-Axis: 
    <select name="no" id="xAxisSelect">
    <option value="pc1" selected="selected">PC1</option>
    <option value="pc2">PC2</option>
    <option value="pc3">PC3</option>
    <option value="pc4">PC4</option>
    <option value="pc5">PC5</option>
    </select>

    Y-Axis: 
    <select name="no" id="yAxisSelect">
    <option value="pc1">PC1</option>
    <option value="pc2" selected="selected">PC2</option>
    <option value="pc3">PC3</option>
    <option value="pc4">PC4</option>
    <option value="pc5">PC5</option>
    </select>

    &nbsp&nbsp&nbspColor By: 
    <select name="no" id="pcaGroupSelect">
""")

    # the entries of the "Color By" box as (value, label): either all tags of
    # the first cell or only the configured cellInfoTags
    if options.showAllTags:
        colorChoices = [(tagName, tagName) for tagName, _ in sorted(cellInfo.values()[0].iteritems())]
    else:
        colorChoices = [(intName, label) for extName, intName, label in sorted(cellInfoTags)]

    for optValue, optLabel in colorChoices:
        if optValue == "meta":
            continue
        htmlFh.write("   <option NAME='no' value='%s'>%s</option>\n" % (optValue, optLabel))

    htmlFh.write("""
    </select>
    </form>
    </div>
</div>

<script>
var pcaChart = null;
var pcaData = %(pcaDataJson)s;

$("#xAxisSelect").val("pc1");
$("#yAxisSelect").val("pc2");
addPcaGraph("pc1", "pc2", $("#pcaGroupSelect").val());

// set the dropdowns to default values and attach listeners
$("#xAxisSelect").change(addPcaGraph);
$("#yAxisSelect").change(addPcaGraph);
$("#pcaGroupSelect").change(addPcaGraph);

</script> """ % {"pcaDataJson" : pcaDataJson})
#(json.dumps(pcaData), cellInfoLabelJson), cellInfoJson);
        
    #logging.info("Wrote %s" % csvFname)
    
    #writePcaPlot(htmlFh)

def getNonZeroGenes(exprMatrix, onlyGenes):
    """ given a matrix with cellId -> geneId -> float, return the sorted list of geneIds that have some non-0 data

    onlyGenes: None or a collection of gene IDs (set, list or dict keyed by
               gene ID). If not None, only these genes are returned.

    Raises AssertionError if no gene is left. This usually means that the
    gene IDs in onlyGenes do not match the IDs in the matrix.
    """
    # first make a list of all genes, remove those that are always 0
    logging.debug("Getting all non-zero genes from matrix")
    allGenes = set()
    notZeroGenes = set()
    for geneDict in exprMatrix.values():
        allGenes.update(geneDict)
        notZeroGenes.update(gene for gene, val in geneDict.items() if val!=0.0)

    skipGenes = allGenes - notZeroGenes
    if len(skipGenes)!=0:
        # sorted, so the message is the same on every run
        logging.warning("These genes have only 0 TPM values across all cells: %s" % ",".join(sorted(skipGenes)))
    keptGenes = allGenes - skipGenes

    if onlyGenes is not None:
        onlyList = list(onlyGenes)
        logging.info("Keeping only %d genes: e.g. %s..." % (len(onlyList), onlyList[:5]))
        keptGenes = keptGenes.intersection(onlyList)

    keptGenes = sorted(keptGenes)
    logging.debug("Found %d non-zero genes in matrix" % len(keptGenes))
    if len(keptGenes)==0:
        # raised explicitly: an assert statement is skipped under 'python -O'
        # and an empty gene list would be passed on
        raise AssertionError("no genes left to write. Probably wrong gene Ids in -s")
    return keptGenes

def dictToGeneMatrix(exprMatrix, onlyGenes):
    """ input: dict cellId -> transcriptId -> count
        returns: list of rows with one gene per row. The first row is the
        header ("geneId", then the sorted cell IDs), the first column is the
        gene ID. Genes are selected by getNonZeroGenes(), a value that is
        missing for a cell is written as 0.
    """
    geneIds = getNonZeroGenes(exprMatrix, onlyGenes)
    cellIds = sorted(exprMatrix)

    # each line is a gene, each column is a cell
    matrix = [["geneId"] + cellIds]
    for geneId in geneIds:
        geneRow = [geneId]
        geneRow.extend(exprMatrix[cellId].get(geneId, 0) for cellId in cellIds)
        matrix.append(geneRow)
    return matrix

def dictToCellMatrix(exprMatrix, onlyGenes):
    """ input: dict cellId -> transcriptId -> count
        output: list of rows with one cell per row. The first row is the
        header ("cellId", then the gene IDs from getNonZeroGenes()), the first
        column is the cell ID. The cells are in the iteration order of
        exprMatrix, a value that is missing for a gene is written as 0.
    """
    geneIds = getNonZeroGenes(exprMatrix, onlyGenes)

    # each line is a cell, each column is a gene
    matrix = [["cellId"] + list(geneIds)]
    for cellName, geneDict in exprMatrix.iteritems():
        cellRow = [cellName]
        cellRow.extend(geneDict.get(gene, 0) for gene in geneIds)
        assert(len(cellRow)!=1) # no marker gene found at all
        matrix.append(cellRow)
    return matrix

def rewriteMatrixForSeurat(matrixFname, exprMatrix, cellInfo, geneToSym, seuratMatFname):
    """ write the expression matrix to seuratMatFname, one gene per row, with the
    column headers in the only format that seurat accepts. Does nothing if
    seuratMatFname already exists. exprMatrix is parsed from matrixFname if it
    is empty. Only the genes in geneToSym are written.
    """
    if isfile(seuratMatFname):
        logging.info("Not exporting matrix for Seurat, %s already exists" % seuratMatFname)
        return

    if not exprMatrix:
        exprMatrix = parseMatrix(matrixFname)

    logging.info("Writing matrix for Seurat to %s" % seuratMatFname)
    seuratMatrix = dictToGeneMatrix(exprMatrix, onlyGenes=geneToSym)
    oldHeaders, dataRows = seuratMatrix[0], seuratMatrix[1:]

    # make the headers conform to the seurat input format Hi_<cellType>_<cellId>
    # the cell type is always written as "NA"
    newHeaders = ["Gene_Symbol"]
    newHeaders.extend("Hi_NA_%s" % (encodeCellId(cellId)) for cellId in oldHeaders[1:])

    writeMatrix(dataRows, seuratMatFname, headers = newHeaders)

    #ofh = open(seuratMatFname, "w")
    #i = 0
    #for line in open(matrixFname):
    #    if i == 0:
    #        line = line.rstrip("\n")
    #        headers = line.split("\t")
    #        newHeaders = ["Gene_Symbol"]
    #        for cellId in headers[1:]:
    #            cellType = cellInfo[cellId].get("cell_type", "noCellTypeFound")
    #            cellType = cellType.replace(" ", "-")
    #            cellType = cellType.replace("_", "-")
    #            cellId = encodeCellId(cellId)
    #            if cellType=="":
    #                cellType="NA"
    #            newHeaders.append("Hi_%s_%s" % (cellType, cellId))
    #        ofh.write("\t".join(newHeaders)+"\n")
    #    else:
    #        geneId, restLine = string.split(line, "\t", 1)
    #        if geneId in geneToSym:
    #            geneId = geneId+"/"+geneToSym[geneId]
    #        # by default, output only genes with symbols, unless special flag is set.
    #        if seuratAllGenes or ("/" in geneId):
    #            ofh.write(geneId+"\t"+restLine)
    #        #if i>100:
    #            #break
    #    i += 1
    #ofh.close()


def makeSeuratScript(seuratMatFname, tsnePath, clusterPath, clusterGenesPath):
    " return a string with the seurat build script "
    seuratScript = """
library(methods)
suppressWarnings(suppressMessages(library(Seurat)))
print("Seurat: Reading data")
nbt.data=read.table("%(seuratMatFname)s",sep="\t",header=TRUE,row.names=1)
#nbt.data=log(nbt.data+1)
nbt=new("seurat",raw.data=nbt.data)
print("Seurat: Setup")
#nbt=setup(nbt,project="NBT",min.cells = 3,names.field = 2,names.delim = "_",min.genes = 500,is.expr=1,)
nbt=Setup(nbt,project = "NBT",min.cells = 3,names.field = 2,names.delim = "_",min.genes = 500, do.logNormalize = T, total.expr = 1e4)
print("Seurat: Mean Variant")
print("CIRM Warning: We are not regressing out cell cycle or mitochondrial signal yet, see http://satijalab.org/seurat/pbmc-tutorial.html")
nbt=MeanVarPlot(nbt,y.cutoff = 2,x.low.cutoff = 2,fxn.x = expMean,fxn.y = logVarDivMean)
print("Seurat: PCA")
nbt=PCA(nbt,do.print=FALSE)
print("Seurat: JackStraw")
# max: changed num.replicate from 200 -> 100
nbt=JackStraw(nbt,num.replicate = 100,do.print = FALSE) 

print("Seurat: PCA Projection")
nbt=ProjectPCA(nbt,do.print=FALSE)

print("Seurat: Significant genes from PCA")
pcCount=min(ncol(nbt@pca.rot), 9)
nbt.sig.genes=PCASigGenes(nbt,1:pcCount,pval.cut = 1e-5,max.per.pc = 200)
print("Seurat: PCA 2 using sign. genes from PCA 1")
nbt=PCA(nbt,pc.genes=nbt.sig.genes,do.print = FALSE)

print("Seurat: JackStraw 2 using all sign. genes")
# max: changed num.replicate from 200 -> 100
nbt=JackStraw(nbt,num.replicate = 100,do.print = FALSE)
print("Seurat: Dimensionality Reduction")
# max: I have decreased the number of iterations from 2000 in the example to 1500
nbt=RunTSNE(nbt,dims.use = 1:11,max_iter=1500)

tsne12 <- FetchData(nbt, c("tSNE_1", "tSNE_2"))
write.table(tsne12, "%(tsnePath)s", quote=FALSE, sep="\t", col.names=NA)

print("Seurat: Cluster the tSNE data")
clusterCount=8
nbt=DBClustDimension(nbt,1,2,reduction.use = "tsne",G.use = clusterCount,set.ident = TRUE)
write.table(FetchData(nbt,c("ident")), "%(clusterPath)s", quote=FALSE, sep="\t", col.names=NA)

#for (i in 0:(clusterCount-1)) {
#    print(paste("Seurat: Writing specific genes for cluster", i))
#    write.table(find.markers(nbt,7,thresh.use = 3), paste0((clusterGenesBase)s, i, ".tab"))
#}

# find_all_markers crashes if there is only a single cluster
if (length(levels(nbt@ident))!=1) {
    markers.all=FindAllMarkers(nbt,test.use = "roc", do.print = TRUE)
    write.table(markers.all, "%(clusterGenesPath)s", quote=FALSE, sep="\t", col.names=NA)
} else {
    file.create("%(clusterGenesPath)s"); # create empty file to signal that all is OK.
}

#print("Seurat: Plot")
#png("(seuratPngPath)s", width=800, height=600)
#tsne.plot(nbt,pt.size = 1)
#dev.off()
""" % locals()
    return seuratScript

def encodeCellId(cellId):
    """ replace the characters _ - + = * in cellId with dots, as they are invalid
    for seurat. Pipe and space characters are not accepted at all (assertion).
    """
    for forbidden in ("|", " "):
        assert(forbidden not in cellId)
    for badChar in "_-+=*":
        cellId = cellId.replace(badChar, ".")
    return cellId

def sanitizeForR(identList):
    """ R does not accept many identifiers. Encode all identifiers in identList with
    encodeCellId() and return a dict saneId -> original ID. Two identifiers must
    not be encoded to the same string (assertion).
    """
    saneToOrig = {}
    for origId in identList:
        saneId = encodeCellId(origId)
        assert(saneId not in saneToOrig) # ID must not appear twice after mapping
        saneToOrig[saneId] = origId
    return saneToOrig

def parseTsneToJson(seuratToOrig, tsnePath):
    """ read the seurat tsne output file tsnePath. Returns a tuple:
    - a JSON string with a list of maps with the attributes x, y, and cellId
    - the same data as a list of tuples (x, y, cellId)
    The part of the row name after the last underscore is translated to the
    original cell ID with the dict seuratToOrig. Rows with an unknown ID are
    skipped and reported in a single warning.
    """
    points = []
    coordList = []
    unknownIds = set()

    for row in lineFileNextRow(tsnePath):
        seuratId = row.rowName.rsplit("_", 1)[-1]
        if seuratId in seuratToOrig:
            cellId = seuratToOrig[seuratId]
            x = float(row.tSNE_1)
            y = float(row.tSNE_2)
            points.append( {"x" : x, "y" : y, "cellId" : cellId } )
            coordList.append( (x, y, cellId) )
        else:
            unknownIds.add(seuratId)

    if len(unknownIds)!=0:
        logging.warn("%d cellIds were not found after seurat mapping: %s" % (len(unknownIds), ",".join(unknownIds)))
    return json.dumps(points), coordList

def parseClusters(newToOrigId, clusterPath):
    """ read the seurat cluster assignment file clusterPath and return a JSON string
    with a map original cellId -> cluster number. Rows with an ID that is not
    in newToOrigId are skipped.
    """
    clusterOfCell = {}
    for row in lineFileNextRow(clusterPath):
        saneId = row.rowName.rsplit("_", 1)[-1]
        if saneId in newToOrigId:
            clusterOfCell[newToOrigId[saneId]] = int(row.ident)
    return json.dumps(clusterOfCell)

def readColorsJson(inFname):
    """ read a file with the three columns attribute, value, color and
    return it as a JSON string: nested map attribute -> value -> color.
    Returns the empty JSON map if inFname is None.
    """
    if inFname is None:
        return "{}"

    colorsByAttr = defaultdict(dict)
    for attrName, attrVal, color in lineFileNextRow(inFname):
        colorsByAttr[attrName][attrVal] = color
    return json.dumps(colorsByAttr)

def writeSeuratSection(htmlFh, matrixFname, exprMatrix, cellInfo, geneToSym, tagLabels, outDir, cellInfoTagToLabel, jsonOutFname):
    """ run matrix through Seurat and add a section to html for it

    The matrix is exported and the R script is run in the tmp directory of
    outDir. R is not run again if the cluster genes file already exists,
    unless options.force is set. The section consists of the chart div, the
    "Color By" box, the plot data as javascript variables and the link to the
    cluster genes page. If jsonOutFname is not None, the t-SNE data is also
    written there with writeNewJson().
    """
    tmpDir = makeTmpDirFor(outDir)
    seuratMatFname = join(tmpDir, "seuratInMatrix.tab")

    rewriteMatrixForSeurat(matrixFname, exprMatrix, cellInfo, geneToSym, seuratMatFname)

    seuratScriptPath = join(tmpDir, "runSeurat.R")
    tsnePath = join(tmpDir, "seuratTsne.tab")
    clusterPath = join(tmpDir, "seuratClusters.tab")
    clusterGenesPath = join(tmpDir, "seuratClusterGenes.tab")
    seuratScript = makeSeuratScript(seuratMatFname, tsnePath, clusterPath, clusterGenesPath)

    tsneColorsJson = readColorsJson(options.colors)

    # the cluster genes file is written last by the R script
    needsRun = not isfile(clusterGenesPath) or options.force
    if needsRun:
        runRscript(seuratScript, seuratScriptPath)
    else:
        logging.info("Not running %s, %s already exists" % (seuratScriptPath, clusterGenesPath))

    newToOrig = sanitizeForR(cellInfo)
    tsneJson, coordList  = parseTsneToJson(newToOrig, tsnePath)

    clusterJson = parseClusters(newToOrig, clusterPath)

    htmlFh.write("""<h2>Seurat T-SNE: gene list trimming and non-linear dimensional reduction</h2>""")
    htmlFh.write("""
<div id="seuratChartDiv"><svg style="width:800px;height:%dpx"></svg>
    <div style="margin-left: 100px">
    Color By: 
    <form action="no">
    <select name="no" id="tsneGroupSelect">
""" % seuratHeight)

    colorBy = options.colorBy

    # list of (value, label) for the drop down box
    # the first entry is the seurat cluster id, this is always possible
    attrList = [("tsneCluster", "Seurat Clustering")]
    if options.showAllTags:
        for tagName, _ in sorted(cellInfo.values()[0].iteritems()):
            if tagName != "meta":
                attrList.append((tagName, tagLabels.get(tagName, tagName)))
    else:
        for extName, intName, label in sorted(cellInfoTags):
            if intName != "meta":
                attrList.append((intName, label))

    for tagName, tagLabel in attrList:
        selectedStr = " selected='selected'" if tagName==colorBy else ""
        htmlFh.write(u'   <option value="%s"%s>%s</option>\n' % (tagName, selectedStr, tagLabel.encode('ascii', errors="ignore")))

    htmlFh.write("""
    </select>
    </form>
    </div>
</div>

<script>""")

    jsValues = {
        "tsneJson" : tsneJson,
        "clusterJson" : clusterJson,
        "tsneColorsJson" : tsneColorsJson,
        "colorBy" : colorBy,
    }
    htmlFh.write("""
var seuratTsnePoints = %(tsneJson)s;
var cellToCluster = %(clusterJson)s;
var tsneColors = %(tsneColorsJson)s;

// set to a default value and add listener
$("#tsneGroupSelect").val("%(colorBy)s");
addSeuratGraph();
$("#tsneGroupSelect").change(addSeuratGraph);
</script>

    """ % jsValues)

    writeSeuratGenes(htmlFh, geneToSym, clusterGenesPath, 0.3, 0.5, outDir)
    if jsonOutFname is not None:
         writeNewJson(coordList, cellInfo, cellInfoTagToLabel, jsonOutFname)

def writeNewJson(coordList, cellInfo, cellInfoTagToLabel, jsonFname):
    """ write t-SNE data in a more modern format to a separate json file

    coordList: list of (x, y, cellId), written as "coords"
    cellInfo:  dict cellId -> dict with the meta data of the cell
    cellInfoTagToLabel: not used
    jsonFname: name of the output file

    The output has three keys: "coords", "metaFields" (sorted list of the
    exported meta field names) and "meta" (cellId -> list of values, in the
    order of "metaFields"). The candidate fields are the keys of the first
    cell in cellInfo, without "meta". A field is only exported if it has more
    than three different values and the values are not all different.
    A field that is missing for a cell is exported as an empty string.
    """
    logging.info("Writing Seurat data in new JSON format to %s" % jsonFname)
    data = {}
    data["coords"] = coordList

    # list of fields, from the first cell. Stays empty if there are no cells.
    fields = []
    for infoDict in cellInfo.values():
        fields = sorted(infoDict)
        break

    # filter down the fields to things that make sense
    # walking over the sorted field list keeps the output order stable
    filtFields = []
    for field in fields:
        if field=="meta":
            continue
        # .get(), like in the export below: a cell without this field must not stop the export
        vals = [infoDict.get(field, "") for infoDict in cellInfo.values()]
        valCount = len(set(vals))
        if valCount > 3 and not valCount==len(vals):
            filtFields.append(field)

    logging.info("Exporting %d fields to json: %s" % (len(filtFields), filtFields))
    data["metaFields"] = filtFields

    # extract these fields
    meta = {}
    for cellId, infoDict in cellInfo.items():
        meta[cellId] = [infoDict.get(field, "") for field in filtFields]
    data["meta"] = meta

    # 'with' makes sure that the file is flushed and closed
    with open(jsonFname, "w") as jsonFh:
        json.dump(data, jsonFh)

def parseSeuratGenes(clusterGenesPath, minPower=0.6):
    """ parse the seurat cluster specific genes table, return dict clusterId -> list of dicts

    Every dict is one row of the table, with "avg_diff" formatted to two
    decimals. The rows of a cluster are kept in the order of the file.
    minPower is not used here.
    """
    logging.info("Parsing %s" % clusterGenesPath)
    clusterToGenes = defaultdict(list)
    # rowName        myAUC   avg_diff        power   cluster gene
    for row in lineFileNextRow(clusterGenesPath):
        rowDict = row._asdict()
        rowDict["avg_diff"] = "%0.2f" % (float(rowDict["avg_diff"]))
        # append the dict with the formatted value, not a fresh copy of the row
        clusterToGenes[row.cluster].append(rowDict)
    return clusterToGenes

def parseAnnots(fnames):
    """ parse key-val files and return as dict

    fnames: list of tab-sep files with two columns, key and value

    Returns a dict key -> string. If a key appears more than once, in one or
    several files, its values are joined with ", " in the order they were read.
    Empty lines are skipped. Raises ValueError for a line that does not have
    exactly two fields.
    """
    keyToVals = defaultdict(list)
    for fname in fnames:
        # 'with' closes each file as soon as it has been read
        with open(fname) as ifh:
            for line in ifh:
                line = line.rstrip("\n")
                if line=="":
                    # e.g. an empty line at the end of the file
                    continue
                key, val = line.split("\t")
                keyToVals[key].append(val)

    return dict((key, ", ".join(vals)) for key, vals in keyToVals.items())

def writeSeuratGenes(htmlFh, geneToSym, clusterGenesPath, minPower, minAuc, outDir):
    """ write a table with the seurat cluster-specific genes.

    Only a link is written to htmlFh, the tables themselves go into a new file
    seuratGenes.html in outDir. Nothing is written at all if clusterGenesPath is
    an empty file.

    htmlFh           - open handle of the page that gets the link
    geneToSym        - dict gene ID -> symbol, IDs without an entry are shown as-is
    clusterGenesPath - seurat table, parsed with parseSeuratGenes
    minPower, minAuc - only rows with power > minPower and myAUC > minAuc are shown
    outDir           - directory for seuratGenes.html
    """
    if getsize(clusterGenesPath)==0:
        logging.info("Seurat did not return cluster-specific genes, skipping this part")
        return

    htmlFh.write('<a href="seuratGenes.html">Seurat cluster-specific genes</a>')

    outFname = join(outDir, "seuratGenes.html")
    logging.info("Writing %s" % outFname)

    clusterToGenes = parseSeuratGenes(clusterGenesPath)
    geneAnnots = parseAnnots([getCirmStaticFile("annotation/panther.tab"),getCirmStaticFile("annotation/markers/pollen/all.tab")])

    # filter rows, clusters without any remaining row are kept (shown with 0 genes)
    filtClusterGenes = []
    for clusterId, rowDicts in clusterToGenes.items():
        filtRowDicts = [
            rowDict for rowDict in rowDicts
            if float(rowDict["power"]) > minPower and float(rowDict["myAUC"]) > minAuc
        ]
        filtClusterGenes.append( (clusterId, filtRowDicts) )

    def clusterSortKey(clusterAndRows):
        " numeric cluster IDs sort by value (2 before 10), all others alphabetically after them "
        clusterId = clusterAndRows[0]
        try:
            return (0, int(clusterId), "")
        except ValueError:
            return (1, 0, clusterId)

    filtClusterGenes.sort(key=clusterSortKey)

    # use a separate handle for the gene page, so the caller's htmlFh is not
    # shadowed and the new file is always closed
    with open(outFname, "w") as geneFh:
        writeHeader(geneFh, "Seurat Cluster Genes")
        geneFh.write("<h2>Seurat cluster-specific genes with power &gt; %.2f and AUC &gt; %.2f </h2>\n" % (minPower, minAuc))

        # write TOC
        geneFh.write("<ul>\n")
        for clusterId, rowDicts in filtClusterGenes:
            geneFh.write("<li><a href='#Cluster%s'>Cluster%s (%d genes)</a></li>\n" % (clusterId, clusterId, len(rowDicts)))
        geneFh.write("</ul>\n")

        # output tables
        for clusterId, rowDicts in filtClusterGenes:
            geneFh.write("<h3 id='Cluster%s'>Cluster%s</h3>\n" % (clusterId, clusterId))
            geneFh.write("<table>\n")
            geneFh.write("<tr><th>Gene</th><th>Avg. Diff</th><th>Power</th><th>AUC</th><th>Gene Type</th></tr>\n")
            for rowDict in rowDicts:
                geneFh.write("<tr>\n")
                # the gene ID is the part after the last slash of the gene field
                ensId = rowDict["gene"].split("/")[-1]
                geneName = geneToSym.get(ensId, ensId)
                geneFh.write("<td><a href='http://uswest.ensembl.org/Homo_sapiens/Gene/Summary?g=%s'>%s</a></td>\n" % (ensId, geneName))
                geneFh.write("<td>%0.2f</td>\n" % float(rowDict["avg_diff"]))
                for key in ["power", "myAUC"]:
                    geneFh.write("<td>%s</td>\n" % (rowDict[key]))

                # annotations are looked up by the displayed name, not by the ID
                geneAnnot = geneAnnots.get(geneName, "")
                geneFh.write("<td>%s</td>\n" % geneAnnot)

                geneFh.write("</tr>\n")
            geneFh.write("</table>\n")

def readTagLabels(fname):
    """ read a tab-sep file with tag<tab>label lines and return a dict tag -> label.

    Only the first two fields of a line are used. Labels are decoded from latin1,
    stripped of surrounding double quotes and shortened to 35 characters plus "..."
    if they are longer. Returns an empty dict if fname is None.
    Raises ValueError if a line has fewer than two fields.
    """
    if fname is None:
        return {}

    tagToLabel = {}
    # 'with' closes the handle even if a malformed line raises
    with open(fname) as ifh:
        for line in ifh:
            tag, label = line.rstrip("\n").split("\t")[:2]
            # lines are byte strings under Python 2: the file content is treated as latin1
            if isinstance(label, bytes):
                label = label.decode("latin1")
            label = label.strip('"')
            if len(label)>35:
                label = label[:35]+"..."
            tagToLabel[tag] = label
    return tagToLabel

def writeCellInfo(htmlFh, cellInfo, tagLabels):
    """ write a <script> section with two JSON objects to htmlFh: cellTagLabels
    (cellInfo key -> label for display) and cellInfo itself.
    Returns the key -> label dict.
    """
    # start with the labels from cellInfoTags, keyed by our internal names
    tagToLabel = {}
    for _extName, intName, label in cellInfoTags:
        tagToLabel[intName] = label

    # labels from tagLabels are added afterwards, so they replace the ones above
    for tag in tagLabels:
        tagToLabel[tag] = tagLabels[tag].encode('utf8')

    # the cellId label is fixed and cannot be overridden
    tagToLabel["cellId"] = "SRA ID"

    labelJson = json.dumps(tagToLabel)
    infoJson = json.dumps(cellInfo)
    scriptParts = [
        "\n<script>\n",
        "  var cellTagLabels = %s;\n" % labelJson,
        "  var cellInfo = %s;\n" % infoJson,
        "</script>",
    ]
    htmlFh.write("".join(scriptParts))
    return tagToLabel

#def parseManifest(maniFname):
    #""" manifest is a tab-sep file,
    #typical header is "#file   meta    format  output"
    #return as dict output -> rows
    #"""
    #ret = defaultdict(list)
    #for row in lineFileNextRow(maniFname):
        #ret[row.output].append(row)
    #return ret

def mustRunCmd(cmd, delOnError=None):
    """ run cmd through os.system with debugging.

    If the command returns a non-zero status, the file delOnError (if given) is
    removed and errAbort is called with an error message. It is not an error if
    delOnError does not exist, as the command may have failed before creating it.
    """
    logging.debug("Running cmd: %s" % cmd)
    ret = os.system(cmd)
    if ret == 0:
        return

    if delOnError:
        try:
            os.remove(delOnError)
        except OSError:
            # do not let a failed cleanup hide the real error reported below
            logging.debug("Could not remove %s" % delOnError)
    errAbort("Could not run command: %s" % cmd)

def makeCacheFname(fname, cacheDir, prefix, suffix):
    """ return the name of a cache file in cacheDir for the input file fname.
    The name is made from the fixed string "snpEff", the prefix, the path of fname
    with slashes replaced by underscores and the size of fname in bytes, followed
    by "." and suffix. fname has to exist, as its size is looked up.
    """
    flatPath = fname.replace("/", "_")
    sizeTag = "%dbytes" % getsize(fname)
    cacheName = "_".join(["snpEff", prefix, flatPath, sizeTag]) + "." + suffix
    return join(cacheDir, cacheName)

#def runSnpEff(fname, tmpDir):
    #cacheFname = makeCacheFname(fname, tmpDir, "snpEff")
    #if isfile(cacheFname):
        #logging.info("%s already exists, not running SnpEFF" % cacheFname)
        #return
#
    #logging.info("Running snpEff on %s" % fname)
    #baseCmd = "java -Xmx4g -jar /pod/home/max/software/snpEff/snpEff.jar GRCh37.75 "
    #cmd = baseCmd+" > "+cacheFname
    #mustRunCmd(cmd, delOnError=cacheFname)

#def runVai(fname, cacheDir):
    #cacheFname = makeCacheFname(fname, cacheDir, "vai")
    #cmd = "zcat %s | vcf-sort | ~/vai.pl hg19 /dev/stdin --hgVai=/usr/local/apache/cgi-bin-max/hgVai -geneTrack=knownGene > %s" % (fname, cacheFname)
    #mustRunCmd(cmd, delOnError=cacheFname)

def parseAnnovar(fname, cacheDir):
    """ parse a single annovar table. Return a dict with:
    - varCount - number of variants in total
    - funcCounts - types of variant changes and their count (list of (type, count)), most common first
    - geneChanges - dict with gene symbol -> list of protein changes (last ":"-separated part of AAChange_refGene)
    - mutToCoord - dict with protein change -> first three columns of the row (presumably chrom, start, end)

    Only rows with an exonic function of "nonsynonymous SNV" or containing "frameshift"
    go into geneChanges and mutToCoord.

    This parsing is pretty slow, so a copy of the results is kept under cacheDir.
    Note that results loaded from the cache have gone through JSON, so tuples come
    back as lists.
    """
    logging.info("Parsing annovar file %s" % fname)
    cacheFname = makeCacheFname(fname, cacheDir, "annovarSummary", "json")
    if isfile(cacheFname):
        logging.info("Using cached results from %s" % cacheFname)
        with open(cacheFname) as cacheFh:
            return json.load(cacheFh)

    funcCounts = Counter()
    total = 0
    symToHgvs = defaultdict(list)
    mutToCoord = dict()

    for row in lineFileNextRow(fname):
        total+=1
        exonFunc = row.ExonicFunc_refGene
        if exonFunc==".":
            exonFunc = "non-coding"
        funcCounts[exonFunc] += 1

        if exonFunc=="nonsynonymous SNV" or "frameshift" in exonFunc:
            fs = row.AAChange_refGene.split(":")
            sym = fs[0]
            protHgvs = fs[-1]
            symToHgvs[sym].append(protHgvs)
            # keyed by the protein change alone, so the same change in another
            # gene overwrites the coordinates
            mutToCoord[protHgvs] = (row[0], row[1], row[2])

    summ = {}
    summ["varCount"] = total
    summ["funcCounts"] = funcCounts.most_common()
    summ["geneChanges"] = dict(symToHgvs)
    summ["mutToCoord"] = mutToCoord

    logging.info("Writing cache to %s" % cacheFname)
    # write to a temp name first: an interrupted run must not leave a truncated
    # cache file that would be picked up and fail to parse on the next run
    tmpFname = cacheFname+".tmp"
    with open(tmpFname, "w") as cacheFh:
        json.dump(summ, cacheFh)
    os.rename(tmpFname, cacheFname)

    return summ

def allSubFiles(dirName, suffix=None):
    """ yield the paths of all files under dirName, including subdirectories.
    Symlinked directories are followed. If suffix is given, only files whose
    name ends with suffix are returned.
    """
    for dirPath, _subDirs, fileNames in os.walk(dirName, followlinks=True):
        for fileName in fileNames:
            if suffix is None or fileName.endswith(suffix):
                yield os.path.join(dirPath, fileName)

def findInputFiles(maniFname, pipeline):
    """ parse the manifest maniFname and return the files of one pipeline.
    The result is a dict meta -> dict output -> fname, where output is a
    value like "variants"/"alignments"/"variantAnnotations".
    Manifest rows of other pipelines are skipped. A meta must not have
    two files with the same output.
    """
    assert(maniFname is not None)
    assert(pipeline is not None)

    metaToFname = defaultdict(dict)
    for row in lineFileNextRow(maniFname):
        if row.pipeline==pipeline:
            outputToFname = metaToFname[row.meta]
            assert(row.output not in outputToFname)
            outputToFname[row.output] = row.file
    return metaToFname

    #bamFnames = []
    #for meta in metas:
        #print meta, metaToFname[meta]
        #bamFnames.append(metaToFname[meta]["alignments"])
    #return bamFnames

def _writeBamTrackDb(trackDbFname, metaToFnames):
    " write a track hub file to trackDbFname, with one BAM track per key of metaToFnames "
    # NOTE(review): the access token is hard-coded here and ends up in the hub file - confirm this is intended
    token = "M0iPQUvDcBQQH09S"
    urlTemplate = "http://cirm-01.local/cgi-bin/cdwGetFile?token=%s&path=%s"

    with open(trackDbFname, "w") as tdbfh:
        tdbfh.write("""hub CIRM-CDW
shortLabel Baldwin 12 Paired IPS Cell Lines
longLabel Baldwin 12 Paired IPS Cell Lines
genomesFile trackDb.txt
email max@soe.ucsc.edu
descriptionUrl trackDb.txt
genome hg19
trackDb trackDb.txt

""")
        for meta, typeFnames in metaToFnames.items():
            # every meta needs both a BAM file and its index, otherwise KeyError
            bamUrl = urlTemplate % (token, typeFnames["alignments"])
            baiUrl = urlTemplate % (token, typeFnames["index"])

            tdbfh.write("track %s\n" % meta)
            tdbfh.write("shortLabel %s\n" % meta)
            tdbfh.write("longLabel %s\n" % meta)
            tdbfh.write("bigDataUrl %s\n" % bamUrl)
            tdbfh.write("bigDataIndex %s\n" % baiUrl)
            tdbfh.write("visibility dense\n")
            tdbfh.write("type bam\n")
            tdbfh.write("\n")

    print("Wrote %s" % trackDbFname)

def _makeMutationRows(allMuts, sampleNames, sampleToNewMuts):
    """ return one row [sampleCount, mutName, color1, color2, ...] per mutation, with one
    color per entry of sampleNames. Rows are sorted by descending sampleCount. """
    # background color of a cell if the sample does not have the mutation, by sample type
    # NOTE(review): the hyphenated names do not look like valid CSS color keywords - confirm
    sampleToColor = {
    "Env" : "light-blue",
    "Ev" : "yellow",
    "Lnv" : "orange",
    "Lv" : "light-green",
    "LvL" : "light-red"
    }

    rows = []
    for mutName in sorted(allMuts):
        colors = []
        sampleCount = 0
        for sampleName in sampleNames:
            if mutName in sampleToNewMuts[sampleName]:
                colors.append("black")
                sampleCount += 1
            else:
                # the sample type is the part of the name before the first underscore
                sampleType = sampleName.split("_")[0]
                if sampleName.startswith("Lv L"):
                    sampleType = "LvL"
                colors.append(sampleToColor[sampleType])
        rows.append([sampleCount, mutName] + colors)

    rows.sort(reverse=True)
    return rows

def writeVcfSummary(htmlFh, trackDbFname, cellInfo, annovarDir, maniFname, pipeline, cacheDir):
    """ summarize the annovar files of one pipeline as a table mutation x sample.

    The annovar files are taken from the "variantAnnotations" entries of the
    manifest maniFname. Only coding mutations that are not found in the sample
    HDF_fib are shown. In addition, a track hub file with the BAM files of all
    samples is written to trackDbFname, it is used by the links in the table.

    htmlFh        - open handle of the html page
    trackDbFname  - name of the track hub file to create
    cellInfo, annovarDir - not used by this function
    maniFname, pipeline  - manifest file and pipeline name, see findInputFiles
    cacheDir      - parsed annovar results are cached in the subdirectory "annovar"

    Raises ValueError if there is no sample HDF_fib.
    """
    htmlFh.write("<h3>Mutated genes</h3>\n")

    annovarCacheDir = join(cacheDir, "annovar")
    if not isdir(annovarCacheDir):
        os.makedirs(annovarCacheDir)
    metaToFnames = findInputFiles(maniFname, pipeline)

    sampleNames = []
    nameToMuts = defaultdict(set) # datasetname -> set of coding mutations
    mutToCoord = {} # mut name -> (chrom, start, end)
    for i, meta in enumerate(metaToFnames):
        sampleFnames = metaToFnames[meta]
        if "variantAnnotations" not in sampleFnames:
            print("No variantAnnotations for %s" % meta)
            continue
        fname = sampleFnames["variantAnnotations"]
        logging.info("processing %d out of %d files" % (i, len(metaToFnames)))
        vcfSumm = parseAnnovar(fname, annovarCacheDir)
        # from here on, samples are named after their annovar file, not after the meta
        sampleName = basename(fname).split(".")[0]
        mutToCoord.update(vcfSumm["mutToCoord"]) # update the mutation -> coordinate mapping
        sampleNames.append(sampleName)

        # flatten gene -> list of mutations into a set of "gene:mutation" strings.
        # Every mutation of a gene has to be added, not just the last one.
        flatChanges = set()
        for gene, muts in vcfSumm["geneChanges"].items():
            for mut in muts:
                flatChanges.add( gene+":"+mut )
        nameToMuts[sampleName] = flatChanges

    # the founder sample is the baseline and does not get a column of its own
    baselineName = "HDF_fib"
    baseLine = nameToMuts[baselineName]
    sampleNames.remove(baselineName)

    # create dict sample -> number of new mutations
    # and dict sample ->  new mutations
    newMutCounts = {}
    sampleToNewMuts = {}
    allMuts = set()
    for sampleName, muts in nameToMuts.items():
        newMuts = muts-baseLine
        print("sample %s: %d mutations, %d new" % (sampleName, len(muts), len(newMuts)))
        newMutCounts[sampleName] = len(newMuts)
        sampleToNewMuts[sampleName] = newMuts
        allMuts.update(newMuts)

    htmlFh.write("Total number of coding mutations in at least one sample: %d" % len(allMuts))

    sampleNames.sort()

    # write the table headers
    htmlFh.write("<table style='table-layout:fixed'>\n")
    htmlFh.write("<thead><tr>\n")
    htmlFh.write("<th style='width:100px'>Mutation</th>\n")
    htmlFh.write("<th style='width:100px'>BAM Reads</th>\n")
    htmlFh.write("<th style='width:50px'>Sample Count</th>\n")
    for sampleName in sampleNames:
        htmlFh.write("<th style='width:60px'>")
        htmlFh.write("%s<br>" % sampleName.replace("_", " "))
        htmlFh.write("<small>%s</small>\n" % newMutCounts[sampleName])
        htmlFh.write("</th>\n")
    htmlFh.write("</tr></thead>\n")

    # now create table rows with cell colors and sort them by sample count
    rows = _makeMutationRows(allMuts, sampleNames, sampleToNewMuts)

    _writeBamTrackDb(trackDbFname, metaToFnames)

    hubUrl = "http://cirm-01-max.pod/datasetSummary/baldwin12PairedIpsCellLines/trackDb.txt"
    for row in rows:
        sampleCount = row[0]
        mutName = row[1]
        colors = row[2:]

        htmlFh.write("<tr>\n")
        # mutName is gene:mutation, but the coordinates are keyed by the mutation alone
        chrom, start, end = mutToCoord[mutName.split(":")[1]]

        if len(mutName)>20:
            mutName = mutName[:20]+"..."

        htmlFh.write("<td>%s</td>\n" % mutName)
        position = "%s:%d-%d" % (chrom, int(start)-1, int(end))

        htmlFh.write("<td><a href='http://cirm-01.local/cgi-bin/hgTracks?db=hg19&position={0}&hubUrl={1}'>Reads</a></td>\n".format(position, hubUrl))
        htmlFh.write("<td>%d</td>\n" % sampleCount)

        for color in colors:
            htmlFh.write("<td style='background-color: %s'></td>\n" % color)
        htmlFh.write("</tr>\n")
    htmlFh.write("</table>\n")

def pasteSummarySection(summDescFn, htmlFh):
    " insert manually generated summary section into html file: copy the content of summDescFn to htmlFh "
    # 'with' closes the input file right away instead of leaving the handle open
    with open(summDescFn) as ifh:
        htmlFh.write(ifh.read())

def writeDatasetPage(datasetId, matrixFname, tagTabFname, kallistoMetaTab, annovarDir, geneToSym, tagLabels, outDir, seuratOut):
    """
    write an html page and various annex .js and .css files
    to outDir. Extracted data are written to summary.tmp in outDir,
    so re-generating the page itself is quick.
    To force a data rebuild, run the script with -f or delete
    the summary.tmp file.

    matrixFname can be the string "none", then all parts of the page that need
    expression data are skipped. If a file summaryDesc.html exists in the current
    directory, it is pasted into the page instead of the generated summary.
    The vcf summary is only written if annovarDir is set.
    """
    if not isdir(outDir):
        os.mkdir(outDir)

    tmpDir = join(outDir, "tmp")
    if not isdir(tmpDir):
        os.mkdir(tmpDir)

    markerFname = options.markerFile
    useSyms = options.useSyms

    cellTypeJsonFname = join(outDir, "cellTypes.json")
    transHistJsonFname = join(outDir, "transCounts.json")
    cacheFname = join(outDir, "summary.tmp")

    global seuratAllGenes
    seuratAllGenes = options.seuratAllGenes

    if matrixFname=="none":
        typeToGenes, geneToTypes = parseMarker(markerFname, useSyms)
        sumData = writeCacheFile(datasetId, {}, tagTabFname, kallistoMetaTab, geneToTypes, outDir, cacheFname)

    else:
        # rebuild the data if there is no cache, if forced or if the histogram json is missing
        if not isfile(cacheFname) or options.force or not isfile(transHistJsonFname):
            typeToGenes, geneToTypes = parseMarker(markerFname, useSyms)

            fullExprMatrix = parseMatrix(matrixFname)
            cellToType = matrixToCellType(fullExprMatrix, typeToGenes, outDir)

            # write json plotting data
            writeCellTypeJson(cellToType, cellTypeJsonFname)
            writeExprTransHistJson(fullExprMatrix, transHistJsonFname)

            sumData = writeCacheFile(datasetId, fullExprMatrix, tagTabFname, kallistoMetaTab, geneToTypes, outDir, cacheFname)

        else:
            with open(cacheFname) as cacheFh:
                sumData = json.load(cacheFh)
            fullExprMatrix = None # load the matrix later if needed

        markerGenes = set(sumData["markers"])

        # the return values are not needed here, only the page that is written
        writeHeatmapPage(outDir, sumData)

    # write main html page, the handle is closed even if one of the sections fails
    with open(join(outDir, "index.html"), "w") as htmlFh:
        summDescFn = "summaryDesc.html"
        writeHeader(htmlFh, "CIRM Dataset Summary: "+datasetId)
        if not isfile(summDescFn):
            generateSummarySection(sumData, htmlFh)
            cellInfo = sumData["cellInfo"]
            cellInfoTagToLabel = writeCellInfo(htmlFh, cellInfo, tagLabels)
        else:
            pasteSummarySection(summDescFn, htmlFh)
            cellInfo = {}
            cellInfoTagToLabel = None

        # only do this part if we have rna-seq data
        if matrixFname!="none":
            writeSeuratSection(htmlFh, matrixFname, fullExprMatrix, cellInfo, geneToSym, tagLabels, outDir, cellInfoTagToLabel, seuratOut)
            if not options.onlySeurat:
                htmlFh.write('''Details: <a href="heatmap.html">Expression Heatmap</a>''')
                writePcaSection(htmlFh, matrixFname, fullExprMatrix, markerGenes, cellInfo, outDir)
                htmlFh.write("<p>Details: <a href='cellInfo.html'>Single Cell Details Table</a>\n")
                writeTransCountSection(htmlFh, transHistJsonFname)

        if annovarDir:
            tmpDir = makeTmpDirFor(outDir)
            trackDbFname = join(outDir, "trackDb.txt")
            # NOTE(review): maniFname and pipeline are not defined in this function, so this
            # line only works if they exist as module-level names - confirm, otherwise it
            # fails with a NameError
            writeVcfSummary(htmlFh, trackDbFname, cellInfo, annovarDir, maniFname, pipeline, tmpDir)

        htmlFh.write("""
</div><!-- mainContent -->
</body></html>""")

    logging.info("Wrote %s" % htmlFh.name)

def main():
    """ entry point: parse the command line, read the symbol and label tables
    and write the dataset page. Sets the global 'options'. """
    global options
    args, options = parseCmdLine()
    mysqlDatasetId, matrixFname, tagTabFname, outDir = args

    # the kallisto meta file name is derived from the name of the meta data file
    kallistoMeta = None
    if tagTabFname!="none":
        kallistoMeta = tagTabFname.replace(".tab", ".kallisto.tab")
        kallistoMeta = kallistoMeta.replace(".tsv", ".kallisto.tab")
        # the meta data file has to contain .tab or .tsv, otherwise nothing was
        # replaced above and something is fishy
        assert(kallistoMeta!=tagTabFname)

    geneToSym = readGeneToSym(options.symTable)
    tagLabels = readTagLabels(options.tagLabels)

    writeDatasetPage(mysqlDatasetId, matrixFname, tagTabFname, kallistoMeta,
        options.annovarDir, geneToSym, tagLabels, outDir, options.seuratOut)

# run the program. There is no __main__ guard: main() is called whenever this
# file is executed, and also if it is imported as a module.
main()
