diff --git a/TreeCluster.py b/TreeCluster.py index a952458..fd9033d 100755 --- a/TreeCluster.py +++ b/TreeCluster.py @@ -10,11 +10,14 @@ # check if user is just printing version if '--version' in argv: - print("TreeCluster version %s" % VERSION); exit() + print(f"TreeCluster version {VERSION}") + exit() # merge two sorted lists into a sorted list def merge_two_sorted_lists(x,y): - out = list(); i = 0; j = 0 + out = [] + i = 0 + j = 0 while i < len(x) and j < len(y): if x[i] < y[j]: out.append(x[i]); i+= 1 @@ -33,7 +36,7 @@ def merge_multi_sorted_lists(lists): if len(lists[l]) != 0: pq.put((lists[l][0],l)) inds = [1 for _ in range(len(lists))] - out = list() + out = [] while not pq.empty(): d,l = pq.get(); out.append(d) if inds[l] < len(lists[l]): @@ -44,9 +47,9 @@ def merge_multi_sorted_lists(lists): # get the median of a sorted list def median(x): if len(x) % 2 != 0: - return x[int(len(x)/2)] + return x[len(x) // 2] else: - return (x[int(len(x)/2)]+x[int(len(x)/2)-1])/2 + return (x[len(x) // 2] + x[len(x) // 2 - 1]) / 2 # get the average of a list def avg(x): @@ -59,8 +62,9 @@ def p_to_jc(d,seq_type): # cut out the current node's subtree (by setting all nodes' DELETED to True) and return list of leaves def cut(node): - cluster = list() - descendants = Queue(); descendants.put(node) + cluster = [] + descendants = Queue() + descendants.put(node) while not descendants.empty(): descendant = descendants.get() if descendant.DELETED: @@ -128,7 +132,7 @@ def pairwise_dists_below_thresh(tree,threshold): # split leaves into minimum number of clusters such that the maximum leaf pairwise distance is below some threshold def min_clusters_threshold_max(tree,threshold,support): leaves = prep(tree,support) - clusters = list() + clusters = [] for node in tree.traverse_postorder(): # if I've already been handled, ignore me if node.DELETED: @@ -179,7 +183,7 @@ def min_clusters_threshold_med_clade(tree,threshold,support): if node.is_leaf(): node.med_pair_dist = 0 node.leaf_dists = [0] - node.pair_dists = list() + node.pair_dists = [] else: children = list(node.children) l_leaf_dists = [d + children[0].edge_length for d in children[0].leaf_dists] @@ -198,7 +202,9 @@ def min_clusters_threshold_med_clade(tree,threshold,support): del c.leaf_dists; del c.pair_dists # perform clustering - q = Queue(); q.put(tree.root); roots = list() + q = Queue() + q.put(tree.root) + roots = [] while not q.empty(): node = q.get() if node.med_pair_dist <= threshold: @@ -210,7 +216,7 @@ def min_clusters_threshold_med_clade(tree,threshold,support): # if verbose, print the clades defined by each cluster if VERBOSE: for root in roots: - print("%s;" % root.newick(), file=stderr) + print(f"{root.newick()};", file=stderr) return [[str(l) for l in root.traverse_leaves()] for root in roots] # average leaf pairwise distance cannot exceed threshold, and clusters must define clades @@ -231,7 +237,9 @@ def min_clusters_threshold_avg_clade(tree,threshold,support): node.avg_pair_dist = node.total_pair_dist/((node.num_leaves*(node.num_leaves-1))/2) # perform clustering - q = Queue(); q.put(tree.root); roots = list() + q = Queue() + q.put(tree.root) + roots = [] while not q.empty(): node = q.get() if node.avg_pair_dist <= threshold: @@ -243,7 +251,7 @@ def min_clusters_threshold_avg_clade(tree,threshold,support): # if verbose, print the clades defined by each cluster if VERBOSE: for root in roots: - print("%s;" % root.newick(), file=stderr) + print(f"{root.newick()};", file=stderr) return [[str(l) for l in root.traverse_leaves()] for root in roots] # total branch length cannot exceed threshold, and clusters must define clades @@ -258,7 +266,9 @@ def min_clusters_threshold_sum_bl_clade(tree,threshold,support): node.total_bl = sum(c.total_bl + c.edge_length for c in node.children) # perform clustering - q = Queue(); q.put(tree.root); roots = list() + q = Queue() + q.put(tree.root) + roots = [] while not q.empty(): node = q.get() if node.total_bl <= threshold: @@ -270,13 +280,13 @@ def min_clusters_threshold_sum_bl_clade(tree,threshold,support): # if verbose, print the clades defined by each cluster if VERBOSE: for root in roots: - print("%s;" % root.newick(), file=stderr) + print(f"{root.newick()};", file=stderr) return [[str(l) for l in root.traverse_leaves()] for root in roots] # total branch length cannot exceed threshold def min_clusters_threshold_sum_bl(tree,threshold,support): leaves = prep(tree,support) - clusters = list() + clusters = [] for node in tree.traverse_postorder(): if node.is_leaf(): node.left_total = 0; node.right_total = 0 @@ -310,7 +320,7 @@ def min_clusters_threshold_sum_bl(tree,threshold,support): # single-linkage clustering using Metin's cut algorithm def single_linkage_cut(tree,threshold,support): leaves = prep(tree,support) - clusters = list() + clusters = [] # find closest leaf below (dist,leaf) for node in tree.traverse_postorder(): @@ -372,7 +382,7 @@ def single_linkage_cut(tree,threshold,support): # single-linkage clustering using Niema's union algorithm def single_linkage_union(tree,threshold,support): leaves = prep(tree,support) - clusters = list() + clusters = [] # find closest leaf below (dist,leaf) for node in tree.traverse_postorder(): @@ -432,7 +442,9 @@ def min_clusters_threshold_max_clade(tree,threshold,support): node.max_pair_dist = max([c.max_pair_dist for c in node.children] + [node.leaf_dist + second_max_leaf_dist]) # perform clustering - q = Queue(); q.put(tree.root); roots = list() + q = Queue() + q.put(tree.root) + roots = [] while not q.empty(): node = q.get() if node.max_pair_dist <= threshold: @@ -444,7 +456,7 @@ def min_clusters_threshold_max_clade(tree,threshold,support): # if verbose, print the clades defined by each cluster if VERBOSE: for root in roots: - print("%s;" % root.newick(), file=stderr) + print(f"{root.newick()};", file=stderr) return [[str(l) for l in root.traverse_leaves()] for root in roots] # pick the threshold between 0 and "threshold" that maximizes number of (non-singleton) clusters @@ -467,7 +479,7 @@ def argmax_clusters(method,tree,threshold,support): # cut all branches longer than the threshold def length(tree,threshold,support): leaves = prep(tree,support) - clusters = list() + clusters = [] for node in tree.traverse_postorder(): # if I've already been handled, ignore me if node.DELETED: @@ -498,7 +510,9 @@ def length_clade(tree,threshold,support): node.max_bl = max([c.max_bl for c in node.children] + [c.edge_length for c in node.children]) # perform clustering - q = Queue(); q.put(tree.root); roots = list() + q = Queue() + q.put(tree.root) + roots = [] while not q.empty(): node = q.get() if node.max_bl <= threshold: @@ -510,13 +524,13 @@ def length_clade(tree,threshold,support): # if verbose, print the clades defined by each cluster if VERBOSE: for root in roots: - print("%s;" % root.newick(), file=stderr) + print(f"{root.newick()};", file=stderr) return [[str(l) for l in root.traverse_leaves()] for root in roots] # cut tree at threshold distance from root (clusters will be clades by definition) (ignores support threshold if branch is below cutting point) def root_dist(tree,threshold,support): leaves = prep(tree,support) - clusters = list() + clusters = [] for node in tree.traverse_preorder(): # if I've already been handled, ignore me if node.DELETED: @@ -540,7 +554,9 @@ def root_dist(tree,threshold,support): # cut tree at threshold distance from the leaves (if tree not ultrametric, max = distance from furthest leaf from root, min = distance from closest leaf to root, avg = average of all leaves) def leaf_dist(tree,threshold,support,mode): modes = {'max':max,'min':min,'avg':avg} - assert mode in modes, "Invalid mode. Must be one of: %s" % ', '.join(sorted(modes.keys())) + assert ( + mode in modes + ), f"Invalid mode. Must be one of: {', '.join(sorted(modes.keys()))}" dist_from_root = modes[mode](d for u,d in tree.distances_from_root(internal=False)) - threshold return root_dist(tree,dist_from_root,support) def leaf_dist_max(tree,threshold,support): @@ -576,13 +592,29 @@ def leaf_dist_avg(tree,threshold,support): parser.add_argument('-o', '--output', required=False, type=str, default='stdout', help="Output File") parser.add_argument('-t', '--threshold', required=True, type=float, help="Length Threshold") parser.add_argument('-s', '--support', required=False, type=float, default=float('-inf'), help="Branch Support Threshold") - parser.add_argument('-m', '--method', required=False, type=str, default='max_clade', help="Clustering Method (options: %s)" % ', '.join(sorted(METHODS.keys()))) - parser.add_argument('-tf', '--threshold_free', required=False, type=str, default=None, help="Threshold-Free Approach (options: %s)" % ', '.join(sorted(THRESHOLDFREE.keys()))) + parser.add_argument( + '-m', + '--method', + required=False, + type=str, + default='max_clade', + help=f"Clustering Method (options: {', '.join(sorted(METHODS.keys()))})", + ) + parser.add_argument( + '-tf', + '--threshold_free', + required=False, + type=str, + default=None, + help=f"Threshold-Free Approach (options: {', '.join(sorted(THRESHOLDFREE.keys()))})", + ) parser.add_argument('-v', '--verbose', action='store_true', help="Verbose Mode") parser.add_argument('--version', action='store_true', help="Display Version") args = parser.parse_args() - assert args.method.lower() in METHODS, "ERROR: Invalid method: %s" % args.method - assert args.threshold_free is None or args.threshold_free in THRESHOLDFREE, "ERROR: Invalid threshold-free approach: %s" % args.threshold_free + assert args.method.lower() in METHODS, f"ERROR: Invalid method: {args.method}" + assert ( + args.threshold_free is None or args.threshold_free in THRESHOLDFREE + ), f"ERROR: Invalid threshold-free approach: {args.threshold_free}" assert args.threshold >= 0, "ERROR: Length threshold must be at least 0" assert args.support >= 0 or args.support == float('-inf'), "ERROR: Branch support must be at least 0" VERBOSE = args.verbose @@ -596,20 +628,23 @@ def leaf_dist_avg(tree,threshold,support): from sys import stdout; outfile = stdout else: outfile = open(args.output,'w') - trees = list() + trees = [] for line in infile: - if isinstance(line,bytes): - l = line.decode().strip() - else: - l = line.strip() + l = line.decode().strip() if isinstance(line,bytes) else line.strip() trees.append(read_tree_newick(l)) # run algorithm - for t,tree in enumerate(trees): - if args.threshold_free is None: - clusters = METHODS[args.method.lower()](tree,args.threshold,args.support) - else: - clusters = THRESHOLDFREE[args.threshold_free](METHODS[args.method.lower()],tree,args.threshold,args.support) + for tree in trees: + clusters = ( + METHODS[args.method.lower()](tree, args.threshold, args.support) + if args.threshold_free is None + else THRESHOLDFREE[args.threshold_free]( + METHODS[args.method.lower()], + tree, + args.threshold, + args.support, + ) + ) outfile.write('SequenceName\tClusterNumber\n') cluster_num = 1 for cluster in clusters: