Coverage for binette/bin_quality.py: 100%
115 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-01-06 19:22 +0000
1#!/usr/bin/env python3
2import logging
3import os
4from collections import Counter
5from itertools import islice
6from typing import Dict, Iterable, Optional, Tuple, Iterator, Set
8import numpy as np
9import pandas as pd
10from binette.bin_manager import Bin
11from tqdm import tqdm
13# Suppress unnecessary TensorFlow warnings
14os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
15logging.getLogger("tensorflow").setLevel(logging.FATAL)
18from checkm2 import keggData, modelPostprocessing, modelProcessing # noqa: E402
def get_bins_metadata_df(
    bins: Iterable[Bin],
    contig_to_cds_count: Dict[str, int],
    contig_to_aa_counter: Dict[str, Counter],
    contig_to_aa_length: Dict[str, int],
) -> pd.DataFrame:
    """
    Build a metadata DataFrame (CDS count, amino-acid length and composition) per bin.

    :param bins: Bin objects to summarise.
    :param contig_to_cds_count: Mapping of contig name to CDS count.
    :param contig_to_aa_counter: Mapping of contig name to amino-acid composition (Counter).
    :param contig_to_aa_length: Mapping of contig name to total amino-acid length.
    :return: DataFrame of bin metadata, indexed by bin id (the "Name" column is kept too).
    """
    # checkm2 defines the canonical column order for the metadata features.
    metadata_order = keggData.KeggCalculator().return_proper_order("Metadata")

    rows = []
    for current_bin in bins:
        # Aggregate per-contig counts over the bin; contigs absent from a
        # lookup table simply contribute nothing.
        cds_total = sum(contig_to_cds_count.get(c, 0) for c in current_bin.contigs)
        aa_total = sum(contig_to_aa_length.get(c, 0) for c in current_bin.contigs)

        aa_composition = Counter()
        for contig in current_bin.contigs:
            if contig in contig_to_aa_counter:
                aa_composition += contig_to_aa_counter[contig]

        row = {"Name": current_bin.id, "CDS": cds_total, "AALength": aa_total}
        row.update(dict(aa_composition))
        rows.append(row)

    metadata_df = pd.DataFrame(rows, columns=["Name"] + metadata_order)
    # Amino acids never observed in a bin show up as NaN; they are true zeros.
    metadata_df = metadata_df.fillna(0).astype({col: int for col in metadata_order})
    return metadata_df.set_index("Name", drop=False)
def get_diamond_feature_per_bin_df(
    bins: Iterable[Bin], contig_to_kegg_counter: Dict[str, Counter]
) -> Tuple[pd.DataFrame, int]:
    """
    Build per-bin Diamond (KEGG) feature counts plus pathway, category and
    module completeness columns, as expected by the checkm2 models.

    :param bins: Bin objects to summarise.
    :param contig_to_kegg_counter: Mapping of contig name to KEGG annotation counter.
    :return: Tuple of (combined feature DataFrame, number of default KEGG orthologs).
    """
    kegg_calc = keggData.KeggCalculator()
    default_kos = kegg_calc.return_default_values_from_category("KO_Genes")

    # Sum KO annotations across each bin's contigs.
    bin_to_ko_counter = {}
    for current_bin in bins:
        ko_counter = Counter()
        for contig in current_bin.contigs:
            if contig in contig_to_kegg_counter:
                ko_counter += contig_to_kegg_counter[contig]
        bin_to_ko_counter[current_bin.id] = ko_counter

    # Rows = bins, columns = the full default KO set; missing KOs are zeros.
    ko_count_per_bin_df = (
        pd.DataFrame(bin_to_ko_counter, index=default_kos)
        .transpose()
        .fillna(0)
        .astype(int)
    )
    ko_count_per_bin_df["Name"] = ko_count_per_bin_df.index

    logging.debug("Calculating completeness of pathways and modules.")
    logging.debug("Calculating pathway completeness information")
    pathway_df = kegg_calc.calculate_KO_group("KO_Pathways", ko_count_per_bin_df.copy())

    logging.debug("Calculating category completeness information")
    category_df = kegg_calc.calculate_KO_group(
        "KO_Categories", ko_count_per_bin_df.copy()
    )

    logging.debug("Calculating module completeness information")
    module_df = kegg_calc.calculate_module_completeness(ko_count_per_bin_df.copy())

    # Column order matters downstream: KO counts, pathways, modules, categories.
    combined = pd.concat(
        [ko_count_per_bin_df, pathway_df, module_df, category_df], axis=1
    )
    return combined, len(default_kos)
def compute_N50(list_of_lengths) -> int:
    """
    Compute the N50 of a collection of sequence lengths.

    N50 is the length L such that sequences of length >= L account for at
    least half of the total length. Returns 0 for an empty input.

    :param list_of_lengths: Lengths of the sequences.
    :return: The N50 value.
    """
    ordered = sorted(list_of_lengths)
    half_total = sum(ordered) / 2

    running = 0
    result = 0
    for result in ordered:
        # Walking from the shortest sequence up, the first length at which
        # the running total reaches half the assembly is the N50.
        running += result
        if running >= half_total:
            return result
    return result
def add_bin_size_and_N50(bins: Iterable[Bin], contig_to_size: Dict[str, int]):
    """
    Set the total length and N50 on each bin from its contigs' sizes (in place).

    :param bins: Bin objects to annotate.
    :param contig_to_size: Mapping of contig names to their sizes.
    """
    for current_bin in bins:
        contig_lengths = [contig_to_size[contig] for contig in current_bin.contigs]
        current_bin.add_length(sum(contig_lengths))
        current_bin.add_N50(compute_N50(contig_lengths))
def add_bin_metrics(
    bins: Set[Bin], contig_info: Dict, contamination_weight: float, threads: int = 1
):
    """
    Compute size, N50 and checkm2 quality metrics for a set of bins (in place).

    :param bins: Set of bin objects to annotate.
    :param contig_info: Dictionary bundling the per-contig lookup tables
        (KEGG counters, CDS counts, amino-acid counters/lengths, contig lengths).
    :param contamination_weight: Weight for contamination assessment.
    :param threads: Number of threads for parallel processing (default is 1).
    :return: The same set of bins, now carrying the computed metrics.
    """
    post_processor = modelPostprocessing.modelProcessor(threads)

    # Pull out all lookup tables up front so a missing key fails early.
    kegg_counters = contig_info["contig_to_kegg_counter"]
    cds_counts = contig_info["contig_to_cds_count"]
    aa_counters = contig_info["contig_to_aa_counter"]
    aa_lengths = contig_info["contig_to_aa_length"]
    contig_lengths = contig_info["contig_to_length"]

    logging.info("Getting bin length and N50")
    add_bin_size_and_N50(bins, contig_lengths)

    logging.info(f"Assessing bin quality for {len(bins)}")
    assess_bins_quality_by_chunk(
        bins,
        kegg_counters,
        cds_counts,
        aa_counters,
        aa_lengths,
        contamination_weight,
        post_processor,
        chunk_size=1000,
    )
    return bins
def chunks(iterable: Iterable, size: int) -> Iterator[Tuple]:
    """
    Yield successive tuples of at most `size` elements from `iterable`.

    :param iterable: The iterable to be divided into chunks.
    :param size: The maximum number of elements per chunk.
    :return: An iterator of tuples; the last one may be shorter than `size`.
    """
    source = iter(iterable)

    def next_chunk() -> Tuple:
        # Consume up to `size` elements; an empty tuple signals exhaustion.
        return tuple(islice(source, size))

    # Two-argument iter() stops when the sentinel (empty tuple) is produced.
    return iter(next_chunk, ())
def assess_bins_quality_by_chunk(
    bins: Iterable[Bin],
    contig_to_kegg_counter: Dict,
    contig_to_cds_count: Dict,
    contig_to_aa_counter: Dict,
    contig_to_aa_length: Dict,
    contamination_weight: float,
    postProcessor: Optional[modelPostprocessing.modelProcessor] = None,
    threads: int = 1,
    chunk_size: int = 2500,
):
    """
    Assess the quality of bins chunk by chunk.

    Processing the bins in fixed-size chunks keeps the per-call feature
    matrices bounded, which improves processing efficiency.

    :param bins: List of bin objects.
    :param contig_to_kegg_counter: Dictionary mapping contig names to KEGG counters.
    :param contig_to_cds_count: Dictionary mapping contig names to CDS counts.
    :param contig_to_aa_counter: Dictionary mapping contig names to amino acid counters.
    :param contig_to_aa_length: Dictionary mapping contig names to amino acid lengths.
    :param contamination_weight: Weight for contamination assessment.
    :param postProcessor: post-processor from checkm2
    :param threads: Number of threads for parallel processing (default is 1).
    :param chunk_size: The size of each chunk.
    """
    with tqdm(total=len(bins), unit="bin") as progress_bar:
        for chunk_index, raw_chunk in enumerate(chunks(bins, chunk_size)):
            bin_chunk = set(raw_chunk)
            logging.debug(
                f"chunk {chunk_index}: assessing quality of {len(bin_chunk)} bins"
            )
            scored_bins = assess_bins_quality(
                bins=bin_chunk,
                contig_to_kegg_counter=contig_to_kegg_counter,
                contig_to_cds_count=contig_to_cds_count,
                contig_to_aa_counter=contig_to_aa_counter,
                contig_to_aa_length=contig_to_aa_length,
                contamination_weight=contamination_weight,
                postProcessor=postProcessor,
                threads=threads,
            )
            progress_bar.update(len(scored_bins))
def assess_bins_quality(
    bins: Iterable[Bin],
    contig_to_kegg_counter: Dict,
    contig_to_cds_count: Dict,
    contig_to_aa_counter: Dict,
    contig_to_aa_length: Dict,
    contamination_weight: float,
    postProcessor: Optional[modelPostprocessing.modelProcessor] = None,
    threads: int = 1,
):
    """
    Assess the quality of bins.

    Builds the checkm2 feature vectors for each bin, runs the general and
    specific completeness/contamination models, selects the appropriate
    predictor per bin via cosine similarity to reference data, and stores
    the resulting scores on the bin objects.
    This code is taken from checkm2 and adjusted.

    :param bins: List of bin objects.
    :param contig_to_kegg_counter: Dictionary mapping contig names to KEGG counters.
    :param contig_to_cds_count: Dictionary mapping contig names to CDS counts.
    :param contig_to_aa_counter: Dictionary mapping contig names to amino acid counters.
    :param contig_to_aa_length: Dictionary mapping contig names to amino acid lengths.
    :param contamination_weight: Weight for contamination assessment.
    :param postProcessor: A post-processor from checkm2
    :param threads: Number of threads for parallel processing (default is 1).
    :return: The input bins, annotated with completeness and contamination scores.
    """
    if postProcessor is None:
        postProcessor = modelPostprocessing.modelProcessor(threads)

    metadata_df = get_bins_metadata_df(
        bins, contig_to_cds_count, contig_to_aa_counter, contig_to_aa_length
    )

    diamond_complete_results, ko_list_length = get_diamond_feature_per_bin_df(
        bins, contig_to_kegg_counter
    )
    # Drop the duplicated "Name" column so the concat keeps a single one
    # (from metadata_df).
    diamond_complete_results = diamond_complete_results.drop(columns=["Name"])

    feature_vectors = pd.concat([metadata_df, diamond_complete_results], axis=1)
    feature_vectors = feature_vectors.sort_values(by="Name")

    # 4: Call general model & specific models and derive predictions"""
    modelProc = modelProcessing.modelProcessor(threads)

    # BUGFIX: `np.float` was deprecated in NumPy 1.20 and removed in 1.24;
    # np.float64 is the dtype that alias resolved to, so behavior is unchanged.
    vector_array = feature_vectors.iloc[:, 1:].values.astype(np.float64)

    logging.debug("Predicting completeness and contamination using the general model.")
    general_results_comp, general_results_cont = modelProc.run_prediction_general(
        vector_array
    )

    logging.debug("Predicting completeness using the specific model.")
    # The specific model sees the KO counts plus the metadata columns
    # (minus the "Name" column), hence the -1.
    specific_model_vector_len = (ko_list_length + len(metadata_df.columns)) - 1

    # also retrieve scaled data for CSM calculations
    specific_results_comp, scaled_features = modelProc.run_prediction_specific(
        vector_array, specific_model_vector_len
    )

    logging.debug(
        "Using cosine similarity to reference data to select an appropriate predictor model."
    )

    # NOTE(review): column 20 of the feature matrix is the value checkm2 uses
    # for the general/specific ratio — taken verbatim from checkm2; confirm if
    # the feature layout ever changes.
    final_comp, final_cont, models_chosen, csm_array = (
        postProcessor.calculate_general_specific_ratio(
            vector_array[:, 20],
            scaled_features,
            general_results_comp,
            general_results_cont,
            specific_results_comp,
        )
    )

    final_results = feature_vectors[["Name"]].copy()
    final_results["Completeness"] = np.round(final_comp, 2)
    final_results["Contamination"] = np.round(final_cont, 2)

    # final_results is indexed by bin id (inherited from metadata_df), so
    # each bin can look up its own scores directly.
    for bin_obj in bins:
        completeness = final_results.at[bin_obj.id, "Completeness"]
        contamination = final_results.at[bin_obj.id, "Contamination"]
        bin_obj.add_quality(completeness, contamination, contamination_weight)
    return bins