#!/usr/bin/env python2.7
# matrixFilter
"""Go through an expression matrix and remove rows with no, or low, expression"""
import os
import sys
import collections
import argparse

# import the UCSC kent python library
sys.path.append(os.path.join(os.path.dirname(__file__), 'pyLib'))
import common
import subprocess
import tempfile

def parseArgs(args):
    """
    Parse the command line arguments.

    Input:
        args - The full argument vector with the program name first
            (same layout as sys.argv). None falls back to sys.argv.
    Output:
        The argparse namespace holding inputFile and outputFile (both
        already opened by argparse), rowThreshold, colThreshold and verbose.
    Prints the help text and exits with status 1 when nothing follows the
    program name.
    """
    # Honour the caller supplied vector instead of silently re-reading
    # sys.argv; None keeps the old "use the process arguments" behaviour.
    argv = sys.argv if args is None else list(args)

    parser = argparse.ArgumentParser(description = __doc__)
    parser.add_argument ("inputFile",
    help = " The input file. ",
    type = argparse.FileType("r"))
    parser.add_argument ("outputFile",
    help = " The output file. ",
    type = argparse.FileType("w"))
    parser.add_argument ("--rowThreshold",
    help = " The threshold that a row must meet to be kept in the matrix ",
    type = int,
    default = 1)
    parser.add_argument ("--colThreshold",
    help = " The threshold that a column must meet to be kept in the matrix ",
    type = int,
    default = 10)
    parser.add_argument ("--verbose",
    help = " Spit out messages during runtime. ",
    action = "store_true",
    default = False)

    if len(argv) == 1:
        # Only the program name was given; show usage rather than an error.
        parser.print_help()
        sys.exit(1)
    # argv[0] is the program name, argparse only wants what follows it.
    return parser.parse_args(argv[1:])


def filterMatrix(inputFile, outputFile, scoreThreshold, verbose):
    """
    Copy a matrix from inputFile to outputFile, dropping low scoring rows.

    Input:
        inputFile - An opened file (r), read line by line.
        outputFile - An opened file (w)
        scoreThreshold - A number, rows summing to less than this are dropped.
        verbose - When true, print the before/after line counts at the end.
    The first line is treated as the header and is always copied through
    unchanged. On every other line the first whitespace separated field is
    skipped (the row label) and the remaining fields are summed as floats;
    the line is written only when the sum is >= scoreThreshold.
    A field that can not be parsed as a float raises ValueError.
    """
    totalRows = 0
    keptRows = 0
    for line in inputFile:
        totalRows += 1
        if totalRows == 1:
            # Header line. It is counted here, when it is actually written,
            # so an empty input reports 0 kept entries instead of 1.
            outputFile.write(line)
            keptRows += 1
            continue
        rowSum = sum(float(score) for score in line.split()[1:])
        if rowSum >= scoreThreshold:
            outputFile.write(line)
            keptRows += 1
    # Both counts include the header line.
    if verbose: print ("Started with %i entries, ended with %i entries"%(totalRows, keptRows))

def _transposeMatrix(inputPath, outputPath, verbose):
    """
    Run the external rowsToCols program to transpose inputPath into outputPath.
    Raises subprocess.CalledProcessError when the program exits non-zero.
    """
    command = ["rowsToCols", inputPath, outputPath]
    if verbose: command.append("-verbose=2")
    # check_call waits for the program and raises on a non-zero exit status,
    # so a failed transpose is no longer silently ignored.
    subprocess.check_call(command)


def main(args):
    """
    Initialized options and calls other functions.

    Input:
        args - The command line argument vector, handed to parseArgs.
    The matrix is filtered by row sum, transposed with rowsToCols, filtered
    again (so the second pass works on the original columns) and transposed
    back into the output file. Intermediate matrices live in named
    temporary files that are closed once the work is done or has failed.
    Raises subprocess.CalledProcessError if rowsToCols reports a failure.
    """
    options = parseArgs(args)
    verbose = options.verbose
    if verbose: print ("Start filtering the matrix of bad rows and columns (matrixFilter).")
    if verbose: print ("The options are " + " ".join(sys.argv[1:]))
    if verbose: print ("Start filtering the rows...")
    firstPassMatrix = tempfile.NamedTemporaryFile(bufsize=1)
    finalPassMatrix = tempfile.NamedTemporaryFile(bufsize=1)
    transposedMatrix = tempfile.NamedTemporaryFile(bufsize=1)
    try:
        filterMatrix(options.inputFile, firstPassMatrix, options.rowThreshold, verbose)
        # The external program reads the file by name, so make sure nothing
        # (for example a last line without a newline) is still buffered here.
        firstPassMatrix.flush()
        if verbose: print ("Transpose the matrix...")
        _transposeMatrix(firstPassMatrix.name, transposedMatrix.name, verbose)
        if verbose: print ("Start filtering the columns...")
        filterMatrix(transposedMatrix, finalPassMatrix, options.colThreshold, verbose)
        finalPassMatrix.flush()
        if verbose: print ("Transpose the matrix into its final form...")
        _transposeMatrix(finalPassMatrix.name, options.outputFile.name, verbose)
    finally:
        # Closing a NamedTemporaryFile also removes it from disk.
        for tempMatrix in (firstPassMatrix, transposedMatrix, finalPassMatrix):
            tempMatrix.close()

    if verbose: print ("Completed filtering the matrix of bad rows and columns (matrixFilter).")

# Script entry point: hand the raw argument vector to main() and use its
# return value as the process exit status.
if "__main__" == __name__:
    raise SystemExit(main(sys.argv))
