|
1 | 1 | use std::{collections::VecDeque, os::unix::process::parent_id};
|
2 | 2 |
|
| 3 | +use anyhow::Ok; |
3 | 4 | use num_traits::Float;
|
4 | 5 | use rand::{Rng, seq::SliceRandom};
|
5 | 6 | use rand_chacha::ChaCha8Rng;
|
@@ -663,4 +664,114 @@ impl LeidenOptimizer {
|
663 | 664 | }
|
664 | 665 | }
|
665 | 666 | }
|
| 667 | + |
| 668 | + fn merge_nodes_constrained<N, G, P>( |
| 669 | + &mut self, |
| 670 | + partitions: &mut [P], |
| 671 | + layer_weights: &[N], |
| 672 | + consider_comms: ConsiderComms, |
| 673 | + constrained_partition: &P, |
| 674 | + max_comm_size: Option<usize>, |
| 675 | + ) -> anyhow::Result<N> |
| 676 | + where |
| 677 | + N: FloatOpsTS + 'static, |
| 678 | + G: NetworkGrouping + Clone + Default, |
| 679 | + P: VertexPartition<N, G>, |
| 680 | + { |
| 681 | + let nb_layers = partitions.len(); |
| 682 | + if nb_layers == 0 { |
| 683 | + return Ok(N::from(-1.0).unwrap()); |
| 684 | + } |
| 685 | + |
| 686 | + let n = partitions[0].node_count(); |
| 687 | + |
| 688 | + // Check all partitions have same number of nodes |
| 689 | + for partition in partitions.iter() { |
| 690 | + if partition.node_count() != n { |
| 691 | + return Err(anyhow::anyhow!( |
| 692 | + "Number of nodes are not equal for all graphs." |
| 693 | + )); |
| 694 | + } |
| 695 | + } |
| 696 | + |
| 697 | + let mut total_improv = N::zero(); |
| 698 | + |
| 699 | + // Establish vertex order and shuffle it |
| 700 | + let mut vertex_order: Vec<usize> = (0..n).collect(); |
| 701 | + vertex_order.shuffle(&mut self.rng); |
| 702 | + |
| 703 | + // Get constrained communities structure |
| 704 | + let constrained_comms = constrained_partition.get_communities(); |
| 705 | + |
| 706 | + let mut comm_added = vec![false; partitions[0].community_count()]; |
| 707 | + let mut comms = Vec::new(); |
| 708 | + |
| 709 | + |
| 710 | + for v in vertex_order { |
| 711 | + let v_comm = partitions[0].membership(v); |
| 712 | + |
| 713 | + if partitions[0].cnodes(v_comm) == 1 { |
| 714 | + for &comm in &comms { |
| 715 | + if comm < comm_added.len() { |
| 716 | + comm_added[comm] = false; |
| 717 | + } |
| 718 | + } |
| 719 | + } |
| 720 | + |
| 721 | + comms.clear(); |
| 722 | + |
| 723 | + self.collect_constrained_candidate_communities(v, partitions, constrained_partition, &constrained_comms, consider_comms, &mut comms, &mut comm_added); |
| 724 | + |
| 725 | + let mut max_comm = v_comm; |
| 726 | + let mut max_improv = if let Some(max_size) = max_comm_size { |
| 727 | + if max_size < partitions[0].csize(v_comm) { |
| 728 | + N::from(f64::NEG_INFINITY).unwrap() |
| 729 | + } else { |
| 730 | + N::zero() |
| 731 | + } |
| 732 | + } else { |
| 733 | + N::zero() |
| 734 | + }; |
| 735 | + |
| 736 | + let v_size = 1; |
| 737 | + |
| 738 | + for &comm in &comms { |
| 739 | + if let Some(max_size) = max_comm_size { |
| 740 | + if max_size < partitions[0].csize(comm) + v_size { |
| 741 | + continue; |
| 742 | + } |
| 743 | + } |
| 744 | + |
| 745 | + let mut possible_improv = N::zero(); |
| 746 | + |
| 747 | + for (layer, partition) in partitions.iter().enumerate() { |
| 748 | + let layer_imrpov = partition.diff_move(v, comm); |
| 749 | + possible_improv += layer_weights[layer] * layer_imrpov; |
| 750 | + } |
| 751 | + |
| 752 | + if possible_improv >= max_improv { |
| 753 | + max_comm = comm; |
| 754 | + max_improv = possible_improv; |
| 755 | + } |
| 756 | + } |
| 757 | + |
| 758 | + |
| 759 | + if max_comm != v_comm { |
| 760 | + total_improv += max_improv; |
| 761 | + |
| 762 | + for partition in partitions.iter_mut() { |
| 763 | + partition.move_node(v, max_comm); |
| 764 | + } |
| 765 | + } |
| 766 | + |
| 767 | + } |
| 768 | + |
| 769 | + partitions[0].renumber_communities(); |
| 770 | + let membership = partitions[0].membership_vector(); |
| 771 | + for partition in partitions.iter_mut().skip(1) { |
| 772 | + partition.set_membership(&membership); |
| 773 | + } |
| 774 | + |
| 775 | + Ok(total_improv) |
| 776 | + } |
666 | 777 | }
|
0 commit comments