Coverage for binette/bin_quality.py: 97%

250 statements  

coverage.py v7.10.7, created at 2025-10-14 14:36 +0000

#!/usr/bin/env python3

import gc
import logging
import os
from collections import Counter, defaultdict
from collections.abc import Iterable, Iterator
from itertools import islice

import joblib
import numpy as np
import pandas as pd
from checkm2 import keggData
from rich.progress import Progress

from binette.bin_manager import Bin

logger = logging.getLogger(__name__)

# Suppress unnecessary TensorFlow warnings
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
logging.getLogger("tensorflow").setLevel(logging.FATAL)

# Lazy loaders for checkm2 components that import keras
# These will only be imported when explicitly called
_modelPostprocessing = None
_modelProcessing = None
_keras_initialized = False


def _initialize_keras_environment():
    """Initialize TensorFlow/Keras to ensure thread safety and memory management"""
    global _keras_initialized
    if not _keras_initialized:
        os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # Suppress TF warnings

        try:
            # Only import keras-related modules when needed
            import tensorflow as tf

            # Use a single thread for predictions to avoid thread contention
            tf.config.threading.set_intra_op_parallelism_threads(1)
            tf.config.threading.set_inter_op_parallelism_threads(1)

            _keras_initialized = True
            logger.debug("TensorFlow/Keras environment initialized")
        except Exception as e:
            logger.warning(f"Failed to fully initialize TensorFlow environment: {e!s}")


def get_modelPostprocessing():
    """Lazy load modelPostprocessing module only when needed"""
    global _modelPostprocessing
    if _modelPostprocessing is None:
        # Initialize Keras environment
        _initialize_keras_environment()

        # Only import keras when absolutely needed
        from checkm2 import modelPostprocessing

        _modelPostprocessing = modelPostprocessing
    return _modelPostprocessing


def get_modelProcessing():
    """Lazy load modelProcessing module only when needed"""
    global _modelProcessing
    if _modelProcessing is None:
        # Initialize Keras environment
        _initialize_keras_environment()

        # Only import keras when absolutely needed
        from checkm2 import modelProcessing

        _modelProcessing = modelProcessing

    return _modelProcessing
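
# Illustrative usage of the lazy loaders above (a sketch, not executed by the module):
# both helpers cache the imported checkm2 module in a module-level global, so the
# heavy TensorFlow/Keras import happens at most once per process.
#
#     postproc_module = get_modelPostprocessing()
#     post_processor = postproc_module.modelProcessor(1)  # same call pattern as used below
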

def get_bins_metadata_df(
    bins: list[Bin],
    contig_to_cds_count: dict[str, int],
    contig_to_aa_counter: dict[str, Counter],
    contig_to_aa_length: dict[str, int],
) -> pd.DataFrame:
    """
    Optimized: Generate a DataFrame containing metadata for a list of bins.
    Handles contigs that appear in multiple bins.
    """
    metadata_order = keggData.KeggCalculator().return_proper_order("Metadata")
    bin_keys = [b.contigs_key for b in bins]

    # --- Pre-aggregate CDS and AA length ---
    cds_per_bin = defaultdict(int)
    aa_len_per_bin = defaultdict(int)
    aa_counter_per_bin = defaultdict(Counter)

    # map contigs → all bins they belong to
    contig_to_bins = defaultdict(list)
    for b in bins:
        for c in b.contigs:
            contig_to_bins[c].append(b.contigs_key)

    # distribute CDS counts
    for contig, cds in contig_to_cds_count.items():
        for bin_key in contig_to_bins.get(contig, []):
            cds_per_bin[bin_key] += cds

    # distribute AA lengths
    for contig, length in contig_to_aa_length.items():
        for bin_key in contig_to_bins.get(contig, []):
            aa_len_per_bin[bin_key] += length

    # distribute AA counters
    for contig, counter in contig_to_aa_counter.items():
        for bin_key in contig_to_bins.get(contig, []):
            aa_counter_per_bin[bin_key].update(counter)

    # --- Build rows ---
    rows = []
    for key in bin_keys:
        row = {
            "Name": key,
            "CDS": cds_per_bin.get(key, 0),
            "AALength": aa_len_per_bin.get(key, 0),
        }
        row.update(aa_counter_per_bin.get(key, {}))
        rows.append(row)

    # --- Construct DataFrame directly ---
    metadata_df = pd.DataFrame(rows).fillna(0)

    # Ensure column order
    all_cols = ["Name"] + metadata_order
    for col in metadata_order:
        if col not in metadata_df.columns:
            metadata_df[col] = 0

    metadata_df = metadata_df[all_cols].astype(dict.fromkeys(metadata_order, int))
    metadata_df = metadata_df.set_index("Name", drop=False)

    return metadata_df
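
# Illustrative note on the aggregation above (sketch with hypothetical values): a contig
# shared by several bins contributes its counts to every bin that contains it, e.g.
#
#     contig_to_bins = {"c1": ["binA", "binB"]}
#     contig_to_cds_count = {"c1": 10}
#     # -> cds_per_bin == {"binA": 10, "binB": 10}
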

def get_diamond_feature_per_bin_df(
    bins: list[Bin], contig_to_kegg_counter: dict[str, Counter]
) -> tuple[pd.DataFrame, int]:
    """
    Optimized: Generate a DataFrame containing Diamond feature counts per bin,
    including KEGG KO counts and completeness information for pathways, categories, and modules.
    Handles contigs that may belong to multiple bins.
    """
    KeggCalc = keggData.KeggCalculator()
    defaultKOs = KeggCalc.return_default_values_from_category("KO_Genes")
    bin_keys = [b.contigs_key for b in bins]

    # --- Build contig → bins mapping ---
    contig_to_bins = defaultdict(list)
    for b in bins:
        for c in b.contigs:
            contig_to_bins[c].append(b.contigs_key)

    # --- Aggregate KO counters per bin ---
    bin_to_ko_counter = {}
    for bin_obj in bins:
        bin_ko_counter = Counter()
        for contig in bin_obj.contigs:
            ko_counter = contig_to_kegg_counter.get(contig)
            if ko_counter:
                bin_ko_counter.update(ko_counter)
        bin_to_ko_counter[bin_obj.contigs_key] = bin_ko_counter

    # --- Build KO count DataFrame directly ---
    ko_count_per_bin_df = (
        pd.DataFrame.from_dict(bin_to_ko_counter, orient="index")
        .reindex(bin_keys)  # keep bin order
        .fillna(0)
        .astype(int)
    )

    # Ensure all defaultKOs exist
    ko_count_per_bin_df = ko_count_per_bin_df.reindex(
        columns=list(defaultKOs), fill_value=0
    )

    # ko_count_per_bin_df.index.name = "Name"
    ko_count_per_bin_df["Name"] = ko_count_per_bin_df.index

    # --- Calculate higher-level completeness ---
    logger.debug("Calculating completeness of pathways, categories, and modules")
    KO_pathways = calculate_KO_group(KeggCalc, "KO_Pathways", ko_count_per_bin_df)
    KO_categories = calculate_KO_group(KeggCalc, "KO_Categories", ko_count_per_bin_df)

    KO_modules = calculate_module_completeness(KeggCalc, ko_count_per_bin_df)

    # --- Concatenate results ---
    diamond_complete_results = pd.concat(
        [ko_count_per_bin_df, KO_pathways, KO_modules, KO_categories], axis=1
    )

    return diamond_complete_results, len(defaultKOs)
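
# Sketch of the KO-count table construction above (hypothetical KO ids and bin names):
#
#     bin_to_ko_counter = {"binA": Counter({"K00001": 2}), "binB": Counter()}
#     df = (pd.DataFrame.from_dict(bin_to_ko_counter, orient="index")
#             .reindex(["binA", "binB"]).fillna(0).astype(int)
#             .reindex(columns=["K00001", "K00002"], fill_value=0))
#     # df has one row per bin and one column per expected KO, zeros where a KO is absent
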

def calculate_KO_group(
    KeggCalc: keggData.KeggCalculator, group: str, KO_gene_data: pd.DataFrame
) -> pd.DataFrame:
    """
    Calculate the completeness of KEGG feature groups per bin.

    :param KeggCalc: An instance of KeggCalculator containing KEGG mappings.
    :param group: Feature group name (e.g., "KO_Pathways", "KO_Categories").
    :param KO_gene_data: DataFrame containing KO counts per bin with last column "Name".

    :return: DataFrame with completeness values for each feature vector in the group.
    """

    # last column is 'Name'
    data = KO_gene_data.drop(columns=["Name"]).values
    n_bins = data.shape[0]

    # Build output DataFrame
    ordered_entries = KeggCalc.return_default_values_from_category(group)
    feature_vectors = list(ordered_entries.keys())
    n_features = len(feature_vectors)

    # Create empty numpy array for results
    result = np.zeros((n_bins, n_features), dtype=float)

    # Map Kegg_IDs to column indices in KO_gene_data
    col_map = {ko: idx for idx, ko in enumerate(KO_gene_data.columns[:-1])}

    for f_idx, vector in enumerate(feature_vectors):
        # KOs belonging to this feature vector
        kegg_ids = KeggCalc.path_category_mapping.loc[
            KeggCalc.path_category_mapping[group] == vector, "Kegg_ID"
        ].values

        # Only keep KOs present in DataFrame columns
        present_cols = [col_map[ko] for ko in kegg_ids if ko in col_map]
        if not present_cols:
            continue

        # Presence/absence: values >1 -> 1
        vals = data[:, present_cols]
        vals[vals > 1] = 1
        result[:, f_idx] = vals.sum(axis=1) / len(kegg_ids)

    return pd.DataFrame(result, columns=feature_vectors, index=KO_gene_data.index)
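
# Worked example for the group completeness above (hypothetical numbers): if a pathway
# maps to 4 KOs (len(kegg_ids) == 4) and a bin has counts [3, 1, 0, 0] for them, the
# counts are capped to presence/absence [1, 1, 0, 0], so the pathway completeness for
# that bin is (1 + 1 + 0 + 0) / 4 = 0.5.
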

def calculate_module_completeness(
    KeggCalc: keggData.KeggCalculator, KO_gene_data: pd.DataFrame
) -> pd.DataFrame:
    """
    Compute module completeness per bin using NumPy for speed.

    :param KeggCalc: An instance of KeggCalculator containing module definitions.
    :param KO_gene_data: DataFrame containing KO counts per bin with last column "Name".

    :return: DataFrame with completeness values for each module.
    """
    data = KO_gene_data.drop(columns=["Name"]).values
    n_bins = data.shape[0]

    modules = list(KeggCalc.module_definitions.keys())
    n_modules = len(modules)

    # Map KO names to column indices
    col_map = {
        ko: idx for idx, ko in enumerate(KO_gene_data.drop(columns=["Name"]).columns)
    }

    # Prepare result array
    result = np.zeros((n_bins, n_modules), dtype=float)

    for m_idx, module in enumerate(modules):
        # Only keep KOs that exist in the DataFrame
        module_kos = [ko for ko in KeggCalc.module_definitions[module] if ko in col_map]
        if not module_kos:
            continue
        cols = [col_map[ko] for ko in module_kos]

        vals = data[:, cols]
        # vals[vals > 1] = 1  # presence/absence

        result[:, m_idx] = vals.sum(axis=1) / len(KeggCalc.module_definitions[module])

    return pd.DataFrame(result, columns=modules, index=KO_gene_data.index)
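
# Worked example for the module completeness above (hypothetical numbers): a module
# defined by 5 KOs, of which 3 are present in the table with raw counts [2, 1, 1],
# gives (2 + 1 + 1) / 5 = 0.8. Note that, unlike calculate_KO_group, the counts are
# not capped at 1 here (the presence/absence line is commented out), so duplicated KOs
# inflate the ratio.
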

def prepare_contig_sizes(contig_to_size: dict[int, int]) -> np.ndarray:
    """
    Prepare a numpy array of contig sizes for fast access.

    :param contig_to_size: Dictionary mapping contig IDs to contig sizes.

    :return: Numpy array where the index corresponds to the contig ID
             and the value is the contig size.
    """
    max_id = max(contig_to_size)
    contig_sizes = np.zeros(max_id + 1, dtype=np.int64)
    for contig_id, size in contig_to_size.items():
        contig_sizes[contig_id] = size
    return contig_sizes
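
# Example (illustrative values): prepare_contig_sizes({1: 100, 3: 250}) returns
# array([0, 100, 0, 250]), so a contig size can be looked up by its integer contig ID.
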

def compute_N50(lengths: np.ndarray) -> int:
    """
    Compute the N50 value for a given set of contig lengths.

    :param lengths: Numpy array of contig lengths.

    :return: N50 value (contig length at which 50% of the genome is covered).
    """
    arr = np.sort(lengths)
    half = arr.sum() / 2
    csum = np.cumsum(arr)
    return arr[np.searchsorted(csum, half)]
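
# Worked example (illustrative values): compute_N50(np.array([2, 3, 4, 5, 6])) == 5.
# The total length is 20, half is 10, and the cumulative sums [2, 5, 9, 14, 20] first
# reach 10 at the contig of length 5.
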

def add_bin_size_and_N50(bins: Iterable[Bin], contig_to_size: dict[int, int]):
    """
    Add bin size and N50 metrics to a list of bin objects.

    :param bins: List of bin objects.
    :param contig_to_size: Dictionary mapping contig IDs to contig sizes.

    :return: None. The bin objects are updated in place with size and N50.
    """
    # TODO use numpy array everywhere instead of contig_to_size
    contig_sizes = prepare_contig_sizes(contig_to_size)

    for bin_obj in bins:
        lengths = contig_sizes[list(bin_obj.contigs)]  # fast bulk lookup
        total_len = lengths.sum()
        n50 = compute_N50(lengths)

        bin_obj.add_length(int(total_len))
        bin_obj.add_N50(int(n50))
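
# The bulk lookup above relies on NumPy fancy indexing (illustrative values):
#
#     contig_sizes = np.array([0, 100, 0, 250])
#     contig_sizes[[1, 3]]  # -> array([100, 250])
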

def add_bin_coding_density(
    bins: list[Bin], contig_to_coding_length: dict[int, int]
) -> None:
    """
    Add the coding density of each given bin.

    :param bins: List of bin objects.
    :param contig_to_coding_length: A dictionary mapping contig IDs to their total coding lengths.

    :return: None. The bin objects are updated in place with their coding density.
    """
    for bin_obj in bins:
        bin_obj.add_coding_density(contig_to_coding_length)

355 

356def add_bin_metrics( 

357 bins: list[Bin], 

358 contig_info: dict, 

359 contamination_weight: float, 

360 threads: int = 1, 

361 checkm2_batch_size: int = 500, 

362 disable_progress_bar: bool = False, 

363): 

364 """ 

365 Add metrics to a Set of bins. 

366 

367 :param bins: Set of bin objects. 

368 :param contig_info: Dictionary containing contig information. 

369 :param contamination_weight: Weight for contamination assessment. 

370 :param threads: Number of threads for parallel processing (default is 1). 

371 If threads=1, all processing happens sequentially using one thread. 

372 If threads>1, processing is parallelized across multiple processes. 

373 The number of parallel workers will be approximately equal to threads. 

374 :param checkm2_batch_size: Maximum number of bins to send to CheckM2 at once within each process 

375 to control memory usage. This creates sub-batches 

376 within each worker to manage CheckM2's memory consumption. 

377 :param disable_progress_bar: Disable the progress bar if True. 

378 

379 :return: List of processed bin objects with quality metrics added. 

380 """ 

381 if not bins: 

382 logger.warning("No bins provided for quality assessment") 

383 return [] 

384 

385 bins_list = list(bins) 

386 

387 logger.info( 

388 f"Assessing bin quality for {len(bins_list)} bins using {threads} threads" 

389 ) 

390 

391 # Extract data from contig_info 

392 contig_to_kegg_counter = contig_info["contig_to_kegg_counter"] 

393 contig_to_cds_count = contig_info["contig_to_cds_count"] 

394 contig_to_aa_counter = contig_info["contig_to_aa_counter"] 

395 contig_to_aa_length = contig_info["contig_to_aa_length"] 

396 

397 def _process_sequential(): 

398 """Helper function for sequential processing""" 

399 modelPostprocessing = get_modelPostprocessing() 

400 postProcessor = modelPostprocessing.modelProcessor(threads) 

401 return assess_bins_quality( 

402 bins=bins_list, 

403 contig_to_kegg_counter=contig_to_kegg_counter, 

404 contig_to_cds_count=contig_to_cds_count, 

405 contig_to_aa_counter=contig_to_aa_counter, 

406 contig_to_aa_length=contig_to_aa_length, 

407 contamination_weight=contamination_weight, 

408 postProcessor=postProcessor, 

409 threads=threads, 

410 checkm2_batch_size=checkm2_batch_size, 

411 ) 

412 

413 min_bins_per_chunk = checkm2_batch_size * 6 

414 

415 if threads == 1 or len(bins_list) <= min_bins_per_chunk * 2: 

416 if len(bins_list) <= min_bins_per_chunk: 

417 logger.info( 

418 f"Only {len(bins_list)} bins (≤ {min_bins_per_chunk}). Using sequential processing to avoid multiprocessing overhead." 

419 ) 

420 return _process_sequential() 

421 

422 # For parallel processing, use joblib 

423 # Calculate number of chunks ensuring each chunk has sufficient work 

424 max_possible_chunks = len(bins_list) // min_bins_per_chunk 

425 n_chunks = max(1, min(threads * 2, max_possible_chunks)) 

426 

427 n_jobs = min(threads, n_chunks) 

428 # Use balanced chunking to distribute work evenly across available threads 

429 chunks_list = balanced_chunks(bins_list, n_chunks) 

430 

431 logger.info( 

432 f"Created {len(chunks_list)} balanced chunks for {n_jobs} parallel jobs" 

433 ) 

434 logger.info( 

435 f"Configuration: {len(bins_list)} bins, {threads} threads, {n_chunks} chunks, batch_size={checkm2_batch_size}" 

436 ) 

437 for idx, chunk in enumerate(chunks_list): 

438 logger.debug(f"Chunk {idx + 1}/{len(chunks_list)} contains {len(chunk)} bins") 

439 

440 # Define a simple function to process a chunk 

441 def process_chunk(chunk_bins): 

442 # Initialize TensorFlow/Keras environment for this subprocess 

443 _initialize_keras_environment() 

444 

445 # Determine optimal thread count for this worker 

446 # For best efficiency, we allocate a portion of total threads to each worker 

447 # Math.ceil(total_threads / n_jobs) would be most aggressive 

448 # But we use 1 thread per worker as the default to avoid oversubscription 

449 worker_threads = max( 

450 1, threads // (2 * n_jobs) 

451 ) # Conservative thread allocation 

452 

453 # Create local processor instance 

454 modelPostprocessing = get_modelPostprocessing() 

455 local_postProcessor = modelPostprocessing.modelProcessor(worker_threads) 

456 

457 # Process bins with nested chunking for memory management 

458 return assess_bins_quality( 

459 bins=chunk_bins, 

460 contig_to_kegg_counter=contig_to_kegg_counter, 

461 contig_to_cds_count=contig_to_cds_count, 

462 contig_to_aa_counter=contig_to_aa_counter, 

463 contig_to_aa_length=contig_to_aa_length, 

464 contamination_weight=contamination_weight, 

465 postProcessor=local_postProcessor, 

466 threads=worker_threads, # Use allocated threads in each worker 

467 checkm2_batch_size=checkm2_batch_size, 

468 ) 

469 

470 # Process chunks in parallel using joblib 

471 with Progress(disable=disable_progress_bar) as progress: 

472 task = progress.add_task("Assessing bin quality", total=len(bins_list)) 

473 

474 # Use joblib for parallelization 

475 results = joblib.Parallel(n_jobs=n_jobs)( 

476 joblib.delayed(process_chunk)(chunk) for chunk in chunks_list 

477 ) 

478 

479 # Combine results 

480 all_bins = [] 

481 for chunk_result in results: 

482 all_bins.extend(chunk_result) 

483 progress.update(task, advance=len(chunk_result)) 

484 

485 return all_bins 
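
# Worked example of the chunking arithmetic above (illustrative numbers): with
# checkm2_batch_size=500, min_bins_per_chunk is 3000. For 20000 bins and 16 threads,
# max_possible_chunks = 20000 // 3000 = 6, n_chunks = max(1, min(32, 6)) = 6 and
# n_jobs = min(16, 6) = 6, so six workers each receive one balanced chunk of roughly
# 3333 bins. With 5000 bins (below the 2 * 3000 threshold), the sequential path is
# taken instead.
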

def chunks(iterable, size: int) -> Iterator[tuple]:
    """
    Generate adjacent chunks of data from an iterable.

    :param iterable: The iterable to be divided into chunks.
    :param size: The size of each chunk.
    :return: An iterator that produces tuples of elements in chunks.
    """
    it = iter(iterable)
    return iter(lambda: tuple(islice(it, size)), ())
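
# Example (illustrative values): list(chunks(range(7), 3)) == [(0, 1, 2), (3, 4, 5), (6,)].
# The iterator stops once islice() yields the empty tuple, which matches the () sentinel.
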

def balanced_chunks(items: list, num_chunks: int) -> list[list]:
    """
    Distribute items into balanced chunks with more even size distribution.

    :param items: List of items to chunk.
    :param num_chunks: Number of chunks to create.
    :return: List of chunks with balanced sizes.
    """
    if num_chunks <= 0:
        return [items]
    if num_chunks >= len(items):
        return [[item] for item in items]

    # Calculate base size and remainder
    base_size = len(items) // num_chunks
    remainder = len(items) % num_chunks

    chunks_list = []
    start_idx = 0

    for i in range(num_chunks):
        # Some chunks get an extra item to distribute the remainder
        chunk_size = base_size + (1 if i < remainder else 0)
        chunk = items[start_idx : start_idx + chunk_size]
        if chunk:  # Only add non-empty chunks
            chunks_list.append(chunk)
        start_idx += chunk_size

    return chunks_list
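
# Example (illustrative values): balanced_chunks(list(range(10)), 3) returns
# [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]; the remainder of 10 % 3 is spread over the
# first chunks so sizes differ by at most one.
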

def assess_bins_quality(
    bins: Iterable[Bin],
    contig_to_kegg_counter: dict,
    contig_to_cds_count: dict,
    contig_to_aa_counter: dict,
    contig_to_aa_length: dict,
    contamination_weight: float,
    checkm2_batch_size: int,
    postProcessor=None,
    threads: int = 1,
):
    """
    Assess the quality of bins.

    This function assesses the quality of bins based on various criteria and assigns
    completeness and contamination scores. This code is adapted from CheckM2.

    :param bins: List of bin objects.
    :param contig_to_kegg_counter: Dictionary mapping contig names to KEGG counters.
    :param contig_to_cds_count: Dictionary mapping contig names to CDS counts.
    :param contig_to_aa_counter: Dictionary mapping contig names to amino acid counters.
    :param contig_to_aa_length: Dictionary mapping contig names to amino acid lengths.
    :param contamination_weight: Weight for contamination assessment.
    :param checkm2_batch_size: Maximum number of bins to process in a single CheckM2 call.
    :param postProcessor: A post-processor from checkm2.
    :param threads: Number of threads for parallel processing (default is 1).
    """
    if postProcessor is None:
        modelPostprocessing = get_modelPostprocessing()
        postProcessor = modelPostprocessing.modelProcessor(threads)

    bins_list = list(bins)

    # If we have fewer bins than the batch size, process them all at once
    if len(bins_list) <= checkm2_batch_size:
        return _assess_bins_quality_batch(
            bins_list,
            contig_to_kegg_counter,
            contig_to_cds_count,
            contig_to_aa_counter,
            contig_to_aa_length,
            contamination_weight,
            postProcessor,
            threads,
        )

    # Split bins into smaller batches for memory management
    logger.debug(
        f"Splitting {len(bins_list)} bins into batches of {checkm2_batch_size} for CheckM2 processing"
    )

    all_processed_bins = []
    batch_chunks = list(chunks(bins_list, checkm2_batch_size))

    for i, batch_bins in enumerate(batch_chunks):
        logger.debug(
            f"Processing CheckM2 batch {i + 1}/{len(batch_chunks)} with {len(batch_bins)} bins"
        )

        # Process this batch
        processed_batch = _assess_bins_quality_batch(
            batch_bins,
            contig_to_kegg_counter,
            contig_to_cds_count,
            contig_to_aa_counter,
            contig_to_aa_length,
            contamination_weight,
            postProcessor,
            threads,
        )

        all_processed_bins.extend(processed_batch)

        # Force garbage collection between batches to free memory
        gc.collect()

    return all_processed_bins
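
# Worked example of the batching above (illustrative numbers): with 1200 bins and
# checkm2_batch_size=500, the bins are processed in three CheckM2 batches of 500, 500
# and 200, with gc.collect() called after each batch to release memory.
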

def _assess_bins_quality_batch(
    bins: list[Bin],
    contig_to_kegg_counter: dict,
    contig_to_cds_count: dict,
    contig_to_aa_counter: dict,
    contig_to_aa_length: dict,
    contamination_weight: float,
    postProcessor,
    threads: int,
):
    """
    Assess the quality of a batch of bins (internal function).

    This function processes a single batch of bins through CheckM2.
    It is called by assess_bins_quality for each batch when memory management is needed.
    """

    metadata_df = get_bins_metadata_df(
        bins, contig_to_cds_count, contig_to_aa_counter, contig_to_aa_length
    )

    diamond_complete_results, ko_list_length = get_diamond_feature_per_bin_df(
        bins, contig_to_kegg_counter
    )
    diamond_complete_results = diamond_complete_results.drop(columns=["Name"])

    feature_vectors = pd.concat([metadata_df, diamond_complete_results], axis=1)
    feature_vectors = feature_vectors.sort_values(by="Name")

    # Create mapping from bin name to bin object for easy lookup
    bin_name_to_bin = {bin_obj.contigs_key: bin_obj for bin_obj in bins}

    # 4: Call the general and specific models and derive predictions
    modelProcessing = get_modelProcessing()
    modelProc = modelProcessing.modelProcessor(threads)

    vector_array = feature_vectors.iloc[:, 1:].values.astype(float)

    logger.debug("Predicting completeness and contamination using the general model")
    general_results_comp, general_results_cont = modelProc.run_prediction_general(
        vector_array
    )

    logger.debug("Predicting completeness using the specific model")
    specific_model_vector_len = (ko_list_length + len(metadata_df.columns)) - 1

    # also retrieve scaled data for CSM calculations
    specific_results_comp, scaled_features = modelProc.run_prediction_specific(
        vector_array, specific_model_vector_len
    )

    logger.debug(
        "Using cosine similarity to reference data to select an appropriate predictor model."
    )

    final_comp, final_cont, models_chosen, csm_array = (
        postProcessor.calculate_general_specific_ratio(
            vector_array[:, 20],
            scaled_features,
            general_results_comp,
            general_results_cont,
            specific_results_comp,
        )
    )

    # Directly iterate through results arrays and lookup corresponding bins
    for bin_name, completeness, contamination, chosen_model in zip(
        feature_vectors["Name"],
        np.round(final_comp, 2),
        np.round(final_cont, 2),
        models_chosen,
        strict=True,
    ):
        bin_obj = bin_name_to_bin[bin_name]
        bin_obj.add_quality(completeness, contamination, contamination_weight)
        bin_obj.add_model(chosen_model)

    return bins