#!/usr/bin/env python2.7
# cdwCleanMetaData
"""Remove single valued meta data terms and redundant meta data terms"""
import argparse
import collections
import os
import subprocess
import sys

def parseArgs(args):
    """
    Parse the command line arguments.

    args is the full argument vector with the program name first (the
    shape of sys.argv); everything after the first element is handed to
    argparse. When args is None, sys.argv is used instead.

    Prints the help text and exits with status 1 when no arguments follow
    the program name. Returns the argparse namespace, where inputFile is a
    file opened for reading and outputFile is a file opened for writing
    (argparse opens both while parsing).
    """
    if args is None:
        args = sys.argv
    parser = argparse.ArgumentParser(description = __doc__)
    parser.add_argument("inputFile",
                        help = " The input file. ",
                        type = argparse.FileType("r"))
    parser.add_argument("outputFile",
                        help = " The output file. ",
                        type = argparse.FileType("w"))
    if (len(args) <= 1):
        parser.print_help()
        sys.exit(1)
    # Parse the vector we were given rather than whatever sys.argv holds,
    # so the function behaves the same when called programmatically.
    options = parser.parse_args(args[1:])
    return options

def getRedundancyTable(splitLine):
    """
    Build the starting list of candidate redundant column pairs.

    Every pair (low, high) of column indexes with low < high is listed
    once. Pairs are grouped by ascending low index and, within a group,
    ordered by descending high index. Only the number of fields in
    splitLine matters; the field values are not looked at.
    """
    colCount = len(splitLine)
    return [(low, high)
            for low in range(colCount)
            for high in range(colCount - 1, low, -1)]

def getDuplicativeTable(splitLine):
    """
    Build the starting table of candidate constant columns.

    Returns a dictionary mapping each column index of splitLine to the
    value found in that column, so later rows can be compared against it.
    """
    return dict(enumerate(splitLine))

def updateRedundancyTable(redTable, splitLine):
    """
    Update the term pairs that have identical values, only
    consider term pairs that have had identical values up to this point to keep
    things lightweight and fast.

    redTable is a list of (i, j) column index pairs; splitLine is the list
    of field values for the current row. Returns a new list holding, in
    their existing order, the pairs from redTable whose two columns carry
    equal values on this row. redTable itself is not modified.
    """
    n = len(splitLine)
    # Filter the surviving pairs instead of rebuilding from every possible
    # pair; rebuilding would bring back pairs an earlier row already ruled
    # out. A pair that points past the end of a short row cannot be
    # confirmed as identical here, so it is dropped as well.
    return [(i, j) for (i, j) in redTable
            if i < n and j < n and splitLine[i] == splitLine[j]]

def updateDuplicativeTable(dupTable, splitLine):
    """
    Narrow the table of candidate constant columns using one more row.

    dupTable maps column index to the value seen so far; splitLine holds
    the current row's field values. Returns a new dictionary containing
    only the columns whose value on this row equals the recorded one.
    dupTable itself is left unchanged.
    """
    stillConstant = dict()
    for col in dupTable:
        current = splitLine[col]
        if dupTable[col] != current:
            continue
        stillConstant[col] = current
    return stillConstant


def main(args):
    """
    Initializes options and calls other functions. Removes 'bad' columns from
    a meta data file. Designed to be fairly fast/low footprint.

    The first line of the input is treated as a header and only fixes the
    column count; every later line is a data row. A column is 'bad' when
    it is still in the constant-value table after the last row, or when it
    is the second member of a pair still in the redundancy table after the
    last row. The remaining columns are extracted with the external 'cut'
    program; when no column is bad the input is copied with 'cp'.

    Returns the exit status of the external command that wrote the output,
    or 0 when no command had to be run.
    """
    options = parseArgs(args)
    sawHeader = False
    sawData = False
    totalCols = 0
    # Start empty so an empty or header-only file never reaches the code
    # below with these names unbound.
    redTable = []
    dupTable = {}
    # Go through the meta data file once to determine which columns are bad.
    for line in options.inputFile:
        splitLine = line.strip("\n").split("\t")
        if not sawHeader:
            sawHeader = True
            totalCols = len(splitLine)
            redTable = getRedundancyTable(splitLine)
            continue
        if not sawData:
            sawData = True
            dupTable = getDuplicativeTable(splitLine)
        if redTable: redTable = updateRedundancyTable(redTable, splitLine)
        if dupTable: dupTable = updateDuplicativeTable(dupTable, splitLine)
    inputName = options.inputFile.name
    outputName = options.outputFile.name
    options.inputFile.close()
    try:
        if not sawData:
            # Without data rows the initial table still lists every pair;
            # none of them has actually been shown to be redundant.
            redTable = []
        if not redTable and not dupTable:
            # Nothing to remove. Argument lists (no shell) keep file names
            # with spaces or shell metacharacters intact.
            return subprocess.call(["cp", inputName, outputName])
        # Make a list of bad columns and tell the user.
        badCols = set()
        for item in redTable:
            print ("These two columns have identical values; %i %i.  Column %i was kept."%(item[0]+1,item[1]+1, item[0]+1))
            badCols.add(item[1])
        for k in dupTable:
            print ("The column %i had a constant value and was removed."%(k+1))
            badCols.add(k)
        # Cut out the good columns and put them into the output file.
        goodCols = [str(n + 1) for n in range(totalCols) if n not in badCols]
        if not goodCols:
            # 'cut -f' rejects an empty field list; the output file was
            # already truncated when it was opened, so leave it empty.
            sys.stderr.write("No columns remain after removing constant and redundant columns; the output file is empty.\n")
            return 0
        return subprocess.call(["cut", "-f", ",".join(goodCols), inputName],
                               stdout=options.outputFile)
    finally:
        options.outputFile.close()
    

# Script entry point: run main on the process arguments and use its
# return value as the exit status.
if __name__ == "__main__":
    exitStatus = main(sys.argv)
    sys.exit(exitStatus)
