Coverage for binette/bin_quality.py: 97%
250 statements
#!/usr/bin/env python3

import gc
import logging
import os
from collections import Counter, defaultdict
from collections.abc import Iterable, Iterator
from itertools import islice

import joblib
import numpy as np
import pandas as pd
from checkm2 import keggData
from rich.progress import Progress

from binette.bin_manager import Bin

logger = logging.getLogger(__name__)

# Suppress unnecessary TensorFlow warnings
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
logging.getLogger("tensorflow").setLevel(logging.FATAL)

# Lazy loaders for checkm2 components that import keras.
# These will only be imported when explicitly called.
_modelPostprocessing = None
_modelProcessing = None
_keras_initialized = False

def _initialize_keras_environment():
    """Initialize TensorFlow/Keras to ensure thread safety and memory management"""
    global _keras_initialized
    if not _keras_initialized:
        os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # Suppress TF warnings

        try:
            # Only import keras-related modules when needed
            import tensorflow as tf

            # Use a single thread for predictions to avoid thread contention
            tf.config.threading.set_intra_op_parallelism_threads(1)
            tf.config.threading.set_inter_op_parallelism_threads(1)

            _keras_initialized = True
            logger.debug("TensorFlow/Keras environment initialized")
        except Exception as e:
            logger.warning(f"Failed to fully initialize TensorFlow environment: {e!s}")

def get_modelPostprocessing():
    """Lazy load the modelPostprocessing module only when needed"""
    global _modelPostprocessing
    if _modelPostprocessing is None:
        # Initialize Keras environment
        _initialize_keras_environment()

        # Only import keras when absolutely needed
        from checkm2 import modelPostprocessing

        _modelPostprocessing = modelPostprocessing
    return _modelPostprocessing

def get_modelProcessing():
    """Lazy load the modelProcessing module only when needed"""
    global _modelProcessing
    if _modelProcessing is None:
        # Initialize Keras environment
        _initialize_keras_environment()

        # Only import keras when absolutely needed
        from checkm2 import modelProcessing

        _modelProcessing = modelProcessing

    return _modelProcessing
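
# Illustrative sketch (not executed at import time): the lazy loaders above are meant
# to be called right before a prediction is needed, so that importing this module
# never pulls in TensorFlow/Keras. A caller typically does something like:
#
#     modelPostprocessing = get_modelPostprocessing()
#     postProcessor = modelPostprocessing.modelProcessor(threads)
#
# Both getters cache the imported module in a module-level global, so repeated calls
# are cheap.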

def get_bins_metadata_df(
    bins: list[Bin],
    contig_to_cds_count: dict[str, int],
    contig_to_aa_counter: dict[str, Counter],
    contig_to_aa_length: dict[str, int],
) -> pd.DataFrame:
    """
    Optimized: Generate a DataFrame containing metadata for a list of bins.
    Handles contigs that appear in multiple bins.
    """
    metadata_order = keggData.KeggCalculator().return_proper_order("Metadata")
    bin_keys = [b.contigs_key for b in bins]

    # --- Pre-aggregate CDS and AA length ---
    cds_per_bin = defaultdict(int)
    aa_len_per_bin = defaultdict(int)
    aa_counter_per_bin = defaultdict(Counter)

    # map contigs → all bins they belong to
    contig_to_bins = defaultdict(list)
    for b in bins:
        for c in b.contigs:
            contig_to_bins[c].append(b.contigs_key)

    # distribute CDS counts
    for contig, cds in contig_to_cds_count.items():
        for bin_key in contig_to_bins.get(contig, []):
            cds_per_bin[bin_key] += cds

    # distribute AA lengths
    for contig, length in contig_to_aa_length.items():
        for bin_key in contig_to_bins.get(contig, []):
            aa_len_per_bin[bin_key] += length

    # distribute AA counters
    for contig, counter in contig_to_aa_counter.items():
        for bin_key in contig_to_bins.get(contig, []):
            aa_counter_per_bin[bin_key].update(counter)

    # --- Build rows ---
    rows = []
    for key in bin_keys:
        row = {
            "Name": key,
            "CDS": cds_per_bin.get(key, 0),
            "AALength": aa_len_per_bin.get(key, 0),
        }
        row.update(aa_counter_per_bin.get(key, {}))
        rows.append(row)

    # --- Construct DataFrame directly ---
    metadata_df = pd.DataFrame(rows).fillna(0)

    # Ensure column order
    all_cols = ["Name"] + metadata_order
    for col in metadata_order:
        if col not in metadata_df.columns:
            metadata_df[col] = 0

    metadata_df = metadata_df[all_cols].astype(dict.fromkeys(metadata_order, int))
    metadata_df = metadata_df.set_index("Name", drop=False)

    return metadata_df
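
# Rough shape of the result, assuming two hypothetical bins (all values made up):
#
#            Name  CDS  AALength    A    C  ...
#     Name
#     b1       b1   12      3400  120   80  ...
#     b2       b2    7      2100   64   51  ...
#
# One row per bin, indexed by (and duplicating) "Name", with the remaining metadata
# columns ordered according to KeggCalculator().return_proper_order("Metadata").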

def get_diamond_feature_per_bin_df(
    bins: list[Bin], contig_to_kegg_counter: dict[str, Counter]
) -> tuple[pd.DataFrame, int]:
    """
    Optimized: Generate a DataFrame containing Diamond feature counts per bin,
    including KEGG KO counts and completeness information for pathways, categories, and modules.
    Handles contigs that may belong to multiple bins.
    """
    KeggCalc = keggData.KeggCalculator()
    defaultKOs = KeggCalc.return_default_values_from_category("KO_Genes")
    bin_keys = [b.contigs_key for b in bins]

    # --- Build contig → bins mapping ---
    contig_to_bins = defaultdict(list)
    for b in bins:
        for c in b.contigs:
            contig_to_bins[c].append(b.contigs_key)

    # --- Aggregate KO counters per bin ---
    bin_to_ko_counter = {}
    for bin_obj in bins:
        bin_ko_counter = Counter()
        for contig in bin_obj.contigs:
            ko_counter = contig_to_kegg_counter.get(contig)
            if ko_counter:
                bin_ko_counter.update(ko_counter)
        bin_to_ko_counter[bin_obj.contigs_key] = bin_ko_counter

    # --- Build KO count DataFrame directly ---
    ko_count_per_bin_df = (
        pd.DataFrame.from_dict(bin_to_ko_counter, orient="index")
        .reindex(bin_keys)  # keep bin order
        .fillna(0)
        .astype(int)
    )

    # Ensure all defaultKOs exist
    ko_count_per_bin_df = ko_count_per_bin_df.reindex(
        columns=list(defaultKOs), fill_value=0
    )

    # ko_count_per_bin_df.index.name = "Name"
    ko_count_per_bin_df["Name"] = ko_count_per_bin_df.index

    # --- Calculate higher-level completeness ---
    logger.debug("Calculating completeness of pathways, categories, and modules")
    KO_pathways = calculate_KO_group(KeggCalc, "KO_Pathways", ko_count_per_bin_df)
    KO_categories = calculate_KO_group(KeggCalc, "KO_Categories", ko_count_per_bin_df)

    KO_modules = calculate_module_completeness(KeggCalc, ko_count_per_bin_df)

    # --- Concatenate results ---
    diamond_complete_results = pd.concat(
        [ko_count_per_bin_df, KO_pathways, KO_modules, KO_categories], axis=1
    )

    return diamond_complete_results, len(defaultKOs)

def calculate_KO_group(
    KeggCalc: keggData.KeggCalculator, group: str, KO_gene_data: pd.DataFrame
) -> pd.DataFrame:
    """
    Calculate the completeness of KEGG feature groups per bin.

    :param KeggCalc: An instance of KeggCalculator containing KEGG mappings.
    :param group: Feature group name (e.g., "KO_Pathways", "KO_Categories").
    :param KO_gene_data: DataFrame containing KO counts per bin with last column "Name".

    :return: DataFrame with completeness values for each feature vector in the group.
    """
    # last column is 'Name'
    data = KO_gene_data.drop(columns=["Name"]).values
    n_bins = data.shape[0]

    # Build the list of feature vectors for this group
    ordered_entries = KeggCalc.return_default_values_from_category(group)
    feature_vectors = list(ordered_entries.keys())
    n_features = len(feature_vectors)

    # Create empty numpy array for results
    result = np.zeros((n_bins, n_features), dtype=float)

    # Map Kegg_IDs to column indices in KO_gene_data
    col_map = {ko: idx for idx, ko in enumerate(KO_gene_data.columns[:-1])}

    for f_idx, vector in enumerate(feature_vectors):
        # KOs belonging to this feature vector
        kegg_ids = KeggCalc.path_category_mapping.loc[
            KeggCalc.path_category_mapping[group] == vector, "Kegg_ID"
        ].values

        # Only keep KOs present in DataFrame columns
        present_cols = [col_map[ko] for ko in kegg_ids if ko in col_map]
        if not present_cols:
            continue

        # Presence/absence: values > 1 -> 1
        vals = data[:, present_cols]
        vals[vals > 1] = 1
        result[:, f_idx] = vals.sum(axis=1) / len(kegg_ids)

    return pd.DataFrame(result, columns=feature_vectors, index=KO_gene_data.index)
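
# Worked example of the ratio above (hypothetical numbers): if a pathway groups 4 KOs
# and a bin has counts [3, 1, 0, 2] for them, presence/absence clipping gives
# [1, 1, 0, 1] and the pathway completeness is 3 / 4 = 0.75. A minimal NumPy sketch
# of that step, independent of the KEGG tables:
#
#     counts = np.array([[3, 1, 0, 2]], dtype=float)
#     counts[counts > 1] = 1
#     completeness = counts.sum(axis=1) / counts.shape[1]   # -> array([0.75])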

def calculate_module_completeness(
    KeggCalc: keggData.KeggCalculator, KO_gene_data: pd.DataFrame
) -> pd.DataFrame:
    """
    Compute module completeness per bin using NumPy for speed.

    :param KeggCalc: An instance of KeggCalculator containing module definitions.
    :param KO_gene_data: DataFrame containing KO counts per bin with last column "Name".

    :return: DataFrame with completeness values for each module.
    """
    data = KO_gene_data.drop(columns=["Name"]).values
    n_bins = data.shape[0]

    modules = list(KeggCalc.module_definitions.keys())
    n_modules = len(modules)

    # Map KO names to column indices
    col_map = {
        ko: idx for idx, ko in enumerate(KO_gene_data.drop(columns=["Name"]).columns)
    }

    # Prepare result array
    result = np.zeros((n_bins, n_modules), dtype=float)

    for m_idx, module in enumerate(modules):
        # Only keep KOs that exist in the DataFrame
        module_kos = [ko for ko in KeggCalc.module_definitions[module] if ko in col_map]
        if not module_kos:
            continue
        cols = [col_map[ko] for ko in module_kos]

        vals = data[:, cols]
        # vals[vals > 1] = 1  # presence/absence

        result[:, m_idx] = vals.sum(axis=1) / len(KeggCalc.module_definitions[module])

    return pd.DataFrame(result, columns=modules, index=KO_gene_data.index)

def prepare_contig_sizes(contig_to_size: dict[int, int]) -> np.ndarray:
    """
    Prepare a numpy array of contig sizes for fast access.

    :param contig_to_size: Dictionary mapping contig IDs to contig sizes.

    :return: Numpy array where the index corresponds to the contig ID
             and the value is the contig size.
    """
    max_id = max(contig_to_size)
    contig_sizes = np.zeros(max_id + 1, dtype=np.int64)
    for contig_id, size in contig_to_size.items():
        contig_sizes[contig_id] = size
    return contig_sizes
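
# Illustrative example (hypothetical contig IDs): prepare_contig_sizes({1: 1000, 3: 500})
# returns array([0, 1000, 0, 500]), so contig_sizes[contig_id] is a constant-time lookup
# and contig_sizes[[1, 3]] fetches several sizes in one vectorised operation.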

def compute_N50(lengths: np.ndarray) -> int:
    """
    Compute the N50 value for a given set of contig lengths.

    :param lengths: Numpy array of contig lengths.

    :return: N50 value (contig length at which 50% of the genome is covered).
    """
    arr = np.sort(lengths)
    half = arr.sum() / 2
    csum = np.cumsum(arr)
    return arr[np.searchsorted(csum, half)]
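
# Illustrative example: for lengths [10, 20, 30, 40] the total is 100, the cumulative
# sums are [10, 30, 60, 100], and the first length whose cumulative sum reaches half
# the total (50) is 30, so:
#
#     compute_N50(np.array([10, 20, 30, 40]))   # -> 30 (as a NumPy integer)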

def add_bin_size_and_N50(bins: Iterable[Bin], contig_to_size: dict[int, int]):
    """
    Add bin size and N50 metrics to a list of bin objects.

    :param bins: List of bin objects.
    :param contig_to_size: Dictionary mapping contig IDs to contig sizes.

    :return: None. The bin objects are updated in place with size and N50.
    """
    # TODO use numpy array everywhere instead of contig_to_size
    contig_sizes = prepare_contig_sizes(contig_to_size)

    for bin_obj in bins:
        lengths = contig_sizes[list(bin_obj.contigs)]  # fast bulk lookup
        total_len = lengths.sum()
        n50 = compute_N50(lengths)

        bin_obj.add_length(int(total_len))
        bin_obj.add_N50(int(n50))
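
# Usage sketch (hypothetical data; `my_bins` is assumed to come from an earlier step):
# if contig 1 has size 1000 and contig 3 has size 500, a bin whose contigs are {1, 3}
# ends up with length 1500 and N50 1000 after this call:
#
#     add_bin_size_and_N50(my_bins, {1: 1000, 3: 500})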

def add_bin_coding_density(
    bins: list[Bin], contig_to_coding_length: dict[int, int]
) -> None:
    """
    Add the coding density to each of the given bins.

    :param bins: List of bin objects, updated in place.
    :param contig_to_coding_length: A dictionary mapping contig IDs to their total coding lengths.

    :return: None. Each bin computes and stores its own coding density.
    """
    for bin_obj in bins:
        bin_obj.add_coding_density(contig_to_coding_length)

def add_bin_metrics(
    bins: list[Bin],
    contig_info: dict,
    contamination_weight: float,
    threads: int = 1,
    checkm2_batch_size: int = 500,
    disable_progress_bar: bool = False,
):
    """
    Add quality metrics to a list of bins.

    :param bins: List of bin objects.
    :param contig_info: Dictionary containing contig information.
    :param contamination_weight: Weight for contamination assessment.
    :param threads: Number of threads for parallel processing (default is 1).
                    If threads=1, all processing happens sequentially using one thread.
                    If threads>1, processing is parallelized across multiple processes,
                    with roughly as many parallel workers as threads.
    :param checkm2_batch_size: Maximum number of bins to send to CheckM2 at once within each
                               process to control memory usage. This creates sub-batches
                               within each worker to manage CheckM2's memory consumption.
    :param disable_progress_bar: Disable the progress bar if True.

    :return: List of processed bin objects with quality metrics added.
    """
    if not bins:
        logger.warning("No bins provided for quality assessment")
        return []

    bins_list = list(bins)

    logger.info(
        f"Assessing bin quality for {len(bins_list)} bins using {threads} threads"
    )

    # Extract data from contig_info
    contig_to_kegg_counter = contig_info["contig_to_kegg_counter"]
    contig_to_cds_count = contig_info["contig_to_cds_count"]
    contig_to_aa_counter = contig_info["contig_to_aa_counter"]
    contig_to_aa_length = contig_info["contig_to_aa_length"]

    def _process_sequential():
        """Helper function for sequential processing"""
        modelPostprocessing = get_modelPostprocessing()
        postProcessor = modelPostprocessing.modelProcessor(threads)
        return assess_bins_quality(
            bins=bins_list,
            contig_to_kegg_counter=contig_to_kegg_counter,
            contig_to_cds_count=contig_to_cds_count,
            contig_to_aa_counter=contig_to_aa_counter,
            contig_to_aa_length=contig_to_aa_length,
            contamination_weight=contamination_weight,
            postProcessor=postProcessor,
            threads=threads,
            checkm2_batch_size=checkm2_batch_size,
        )

    min_bins_per_chunk = checkm2_batch_size * 6

    if threads == 1 or len(bins_list) <= min_bins_per_chunk * 2:
        if len(bins_list) <= min_bins_per_chunk:
            logger.info(
                f"Only {len(bins_list)} bins (≤ {min_bins_per_chunk}). Using sequential processing to avoid multiprocessing overhead."
            )
        return _process_sequential()

    # For parallel processing, use joblib.
    # Calculate the number of chunks, ensuring each chunk has sufficient work.
    max_possible_chunks = len(bins_list) // min_bins_per_chunk
    n_chunks = max(1, min(threads * 2, max_possible_chunks))

    n_jobs = min(threads, n_chunks)
    # Use balanced chunking to distribute work evenly across available threads
    chunks_list = balanced_chunks(bins_list, n_chunks)

    logger.info(
        f"Created {len(chunks_list)} balanced chunks for {n_jobs} parallel jobs"
    )
    logger.info(
        f"Configuration: {len(bins_list)} bins, {threads} threads, {n_chunks} chunks, batch_size={checkm2_batch_size}"
    )
    for idx, chunk in enumerate(chunks_list):
        logger.debug(f"Chunk {idx + 1}/{len(chunks_list)} contains {len(chunk)} bins")

    # Define a simple function to process a chunk
    def process_chunk(chunk_bins):
        # Initialize TensorFlow/Keras environment for this subprocess
        _initialize_keras_environment()

        # Determine the thread count for this worker.
        # Allocating ceil(total_threads / n_jobs) per worker would be the most aggressive choice;
        # a conservative share (at least 1 thread) is used instead to avoid oversubscription.
        worker_threads = max(
            1, threads // (2 * n_jobs)
        )  # Conservative thread allocation

        # Create a local processor instance
        modelPostprocessing = get_modelPostprocessing()
        local_postProcessor = modelPostprocessing.modelProcessor(worker_threads)

        # Process bins with nested chunking for memory management
        return assess_bins_quality(
            bins=chunk_bins,
            contig_to_kegg_counter=contig_to_kegg_counter,
            contig_to_cds_count=contig_to_cds_count,
            contig_to_aa_counter=contig_to_aa_counter,
            contig_to_aa_length=contig_to_aa_length,
            contamination_weight=contamination_weight,
            postProcessor=local_postProcessor,
            threads=worker_threads,  # Use the allocated threads in each worker
            checkm2_batch_size=checkm2_batch_size,
        )

    # Process chunks in parallel using joblib
    with Progress(disable=disable_progress_bar) as progress:
        task = progress.add_task("Assessing bin quality", total=len(bins_list))

        # Use joblib for parallelization
        results = joblib.Parallel(n_jobs=n_jobs)(
            joblib.delayed(process_chunk)(chunk) for chunk in chunks_list
        )

        # Combine results
        all_bins = []
        for chunk_result in results:
            all_bins.extend(chunk_result)
            progress.update(task, advance=len(chunk_result))

    return all_bins
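
# Minimal usage sketch (hypothetical variable names; `my_bins` and the contig
# dictionaries are assumed to come from earlier pipeline steps):
#
#     contig_info = {
#         "contig_to_kegg_counter": contig_to_kegg_counter,
#         "contig_to_cds_count": contig_to_cds_count,
#         "contig_to_aa_counter": contig_to_aa_counter,
#         "contig_to_aa_length": contig_to_aa_length,
#     }
#     scored_bins = add_bin_metrics(my_bins, contig_info, contamination_weight=2, threads=4)
#
# With the default checkm2_batch_size of 500, min_bins_per_chunk is 3000, so parallel
# chunking only kicks in above 6000 bins; smaller inputs are processed sequentially.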

def chunks(iterable, size: int) -> Iterator[tuple]:
    """
    Generate adjacent chunks of data from an iterable.

    :param iterable: The iterable to be divided into chunks.
    :param size: The size of each chunk.
    :return: An iterator that produces tuples of elements in chunks.
    """
    it = iter(iterable)
    return iter(lambda: tuple(islice(it, size)), ())
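
# Illustrative example: chunks(range(5), 2) yields (0, 1), (2, 3), (4,). The last chunk
# is simply shorter, which is why balanced_chunks below is preferred when the goal is
# even work distribution rather than a fixed chunk size.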

def balanced_chunks(items: list, num_chunks: int) -> list[list]:
    """
    Distribute items into balanced chunks with more even size distribution.

    :param items: List of items to chunk.
    :param num_chunks: Number of chunks to create.
    :return: List of chunks with balanced sizes.
    """
    if num_chunks <= 0:
        return [items]
    if num_chunks >= len(items):
        return [[item] for item in items]

    # Calculate base size and remainder
    base_size = len(items) // num_chunks
    remainder = len(items) % num_chunks

    chunks_list = []
    start_idx = 0

    for i in range(num_chunks):
        # Some chunks get an extra item to distribute the remainder
        chunk_size = base_size + (1 if i < remainder else 0)
        chunk = items[start_idx : start_idx + chunk_size]
        if chunk:  # Only add non-empty chunks
            chunks_list.append(chunk)
        start_idx += chunk_size

    return chunks_list
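
# Illustrative example: balanced_chunks(list(range(10)), 3) returns
# [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]. The remainder is spread over the first chunks,
# so chunk sizes never differ by more than one item.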

def assess_bins_quality(
    bins: Iterable[Bin],
    contig_to_kegg_counter: dict,
    contig_to_cds_count: dict,
    contig_to_aa_counter: dict,
    contig_to_aa_length: dict,
    contamination_weight: float,
    checkm2_batch_size: int,
    postProcessor=None,
    threads: int = 1,
):
    """
    Assess the quality of bins.

    This function assesses the quality of bins based on various criteria and assigns
    completeness and contamination scores. This code is taken from CheckM2 and adjusted.

    :param bins: List of bin objects.
    :param contig_to_kegg_counter: Dictionary mapping contig names to KEGG counters.
    :param contig_to_cds_count: Dictionary mapping contig names to CDS counts.
    :param contig_to_aa_counter: Dictionary mapping contig names to amino acid counters.
    :param contig_to_aa_length: Dictionary mapping contig names to amino acid lengths.
    :param contamination_weight: Weight for contamination assessment.
    :param checkm2_batch_size: Maximum number of bins to process in a single CheckM2 call.
    :param postProcessor: A post-processor from CheckM2.
    :param threads: Number of threads for parallel processing (default is 1).
    """
    if postProcessor is None:
        modelPostprocessing = get_modelPostprocessing()
        postProcessor = modelPostprocessing.modelProcessor(threads)

    bins_list = list(bins)

    # If we have fewer bins than the batch size, process them all at once
    if len(bins_list) <= checkm2_batch_size:
        return _assess_bins_quality_batch(
            bins_list,
            contig_to_kegg_counter,
            contig_to_cds_count,
            contig_to_aa_counter,
            contig_to_aa_length,
            contamination_weight,
            postProcessor,
            threads,
        )

    # Split bins into smaller batches for memory management
    logger.debug(
        f"Splitting {len(bins_list)} bins into batches of {checkm2_batch_size} for CheckM2 processing"
    )

    all_processed_bins = []
    batch_chunks = list(chunks(bins_list, checkm2_batch_size))

    for i, batch_bins in enumerate(batch_chunks):
        logger.debug(
            f"Processing CheckM2 batch {i + 1}/{len(batch_chunks)} with {len(batch_bins)} bins"
        )

        # Process this batch
        processed_batch = _assess_bins_quality_batch(
            batch_bins,
            contig_to_kegg_counter,
            contig_to_cds_count,
            contig_to_aa_counter,
            contig_to_aa_length,
            contamination_weight,
            postProcessor,
            threads,
        )

        all_processed_bins.extend(processed_batch)

        # Force garbage collection between batches to free memory
        gc.collect()

    return all_processed_bins
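
# Batching arithmetic (illustrative): with 1200 bins and checkm2_batch_size=500, the
# bins are processed in three CheckM2 calls of 500, 500 and 200 bins, with a
# gc.collect() between calls to keep peak memory bounded.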

def _assess_bins_quality_batch(
    bins: list[Bin],
    contig_to_kegg_counter: dict,
    contig_to_cds_count: dict,
    contig_to_aa_counter: dict,
    contig_to_aa_length: dict,
    contamination_weight: float,
    postProcessor,
    threads: int,
):
    """
    Assess the quality of a batch of bins (internal function).

    This function processes a single batch of bins through CheckM2.
    It is called by assess_bins_quality for each batch when memory management is needed.
    """
    metadata_df = get_bins_metadata_df(
        bins, contig_to_cds_count, contig_to_aa_counter, contig_to_aa_length
    )

    diamond_complete_results, ko_list_length = get_diamond_feature_per_bin_df(
        bins, contig_to_kegg_counter
    )
    diamond_complete_results = diamond_complete_results.drop(columns=["Name"])

    feature_vectors = pd.concat([metadata_df, diamond_complete_results], axis=1)
    feature_vectors = feature_vectors.sort_values(by="Name")

    # Create a mapping from bin name to bin object for easy lookup
    bin_name_to_bin = {bin_obj.contigs_key: bin_obj for bin_obj in bins}

    # Call the general and specific models and derive predictions
    modelProcessing = get_modelProcessing()
    modelProc = modelProcessing.modelProcessor(threads)

    vector_array = feature_vectors.iloc[:, 1:].values.astype(float)

    logger.debug("Predicting completeness and contamination using the general model")
    general_results_comp, general_results_cont = modelProc.run_prediction_general(
        vector_array
    )

    logger.debug("Predicting completeness using the specific model")
    specific_model_vector_len = (ko_list_length + len(metadata_df.columns)) - 1

    # also retrieve scaled data for CSM calculations
    specific_results_comp, scaled_features = modelProc.run_prediction_specific(
        vector_array, specific_model_vector_len
    )

    logger.debug(
        "Using cosine similarity to reference data to select an appropriate predictor model."
    )

    final_comp, final_cont, models_chosen, csm_array = (
        postProcessor.calculate_general_specific_ratio(
            vector_array[:, 20],
            scaled_features,
            general_results_comp,
            general_results_cont,
            specific_results_comp,
        )
    )

    # Directly iterate through the result arrays and look up the corresponding bins
    for bin_name, completeness, contamination, chosen_model in zip(
        feature_vectors["Name"],
        np.round(final_comp, 2),
        np.round(final_cont, 2),
        models_chosen,
        strict=True,
    ):
        bin_obj = bin_name_to_bin[bin_name]
        bin_obj.add_quality(completeness, contamination, contamination_weight)
        bin_obj.add_model(chosen_model)

    return bins