@@ -121,7 +121,7 @@ def any_row_contains_duplicate_values(n_processes: int, frame: pd.DataFrame) ->
121
121
is_duplicated = pool .map (nunique , iter (numpy_array ))
122
122
return any (is_duplicated )
123
123
124
- def concatenate_matrices (n_processes : int , input_ids : tuple [str ], matrices : Iterable [pd .DataFrame ]) \
124
+ def concatenate_matrices (n_processes : int , input_ids : tuple [str ], matrices : Iterable [pd .DataFrame ], align_to : pd . Index | None ) \
125
125
-> tuple [dict [str , pd .DataFrame ], pd .DataFrame | None , dict [str , pd .core .dtypes .dtypes .Dtype ]]:
126
126
"""
127
127
Merge matrices by combining columns that have the same name.
@@ -131,12 +131,13 @@ def concatenate_matrices(n_processes: int, input_ids: tuple[str], matrices: Iter
131
131
column_names = set (column_name for var in matrices for column_name in var )
132
132
logger .debug ('Trying to concatenate columns: %s.' , "," .join (column_names ))
133
133
if not column_names :
134
- return {}, None
134
+ return {}, pd . DataFrame ( index = align_to )
135
135
conflicts , concatenated_matrix = \
136
136
split_conflicts_and_concatenated_columns (n_processes ,
137
137
input_ids ,
138
138
matrices ,
139
- column_names )
139
+ column_names ,
140
+ align_to )
140
141
concatenated_matrix = cast_to_writeable_dtype (concatenated_matrix )
141
142
conflicts = {conflict_name : cast_to_writeable_dtype (conflict_df )
142
143
for conflict_name , conflict_df in conflicts .items ()}
@@ -152,7 +153,8 @@ def get_first_non_na_value_vector(df):
152
153
def split_conflicts_and_concatenated_columns (n_processes : int ,
153
154
input_ids : tuple [str ],
154
155
matrices : Iterable [pd .DataFrame ],
155
- column_names : Iterable [str ]) -> \
156
+ column_names : Iterable [str ],
157
+ align_to : pd .Index | None = None ) -> \
156
158
tuple [dict [str , pd .DataFrame ], pd .DataFrame ]:
157
159
"""
158
160
Retrieve columns with the same name from a list of dataframes which are
@@ -166,19 +168,21 @@ def split_conflicts_and_concatenated_columns(n_processes: int,
166
168
for column_name in column_names :
167
169
columns = [var [column_name ] for var in matrices if column_name in var ]
168
170
assert columns , "Some columns should have been found."
169
- concatenated_columns = pd .concat (columns , axis = 1 , join = "outer" )
171
+ concatenated_columns = pd .concat (columns , axis = 1 , join = "outer" , sort = False )
170
172
if any_row_contains_duplicate_values (n_processes , concatenated_columns ):
171
173
concatenated_columns .columns = input_ids
174
+ if align_to is not None :
175
+ concatenated_columns = concatenated_columns .reindex (align_to , copy = False )
172
176
conflicts [f'conflict_{ column_name } ' ] = concatenated_columns
173
177
else :
174
178
unique_values = get_first_non_na_value_vector (concatenated_columns )
175
- # concatenated_columns.fillna(method='bfill', axis=1).iloc[:, 0]
176
179
concatenated_matrix .append (unique_values )
177
- if concatenated_matrix :
178
- concatenated_matrix = pd .concat (concatenated_matrix , join = "outer" , axis = 1 )
179
- else :
180
- concatenated_matrix = pd .DataFrame ()
181
-
180
+ if not concatenated_matrix :
181
+ return conflicts , pd .DataFrame (index = align_to )
182
+ concatenated_matrix = pd .concat (concatenated_matrix , join = "outer" ,
183
+ axis = 1 , sort = False )
184
+ if align_to is not None :
185
+ concatenated_matrix = concatenated_matrix .reindex (align_to , copy = False )
182
186
return conflicts , concatenated_matrix
183
187
184
188
def cast_to_writeable_dtype (result : pd .DataFrame ) -> pd .DataFrame :
@@ -213,15 +217,17 @@ def split_conflicts_modalities(n_processes: int, input_ids: tuple[str], samples:
213
217
matrices_to_parse = ("var" , "obs" )
214
218
for matrix_name in matrices_to_parse :
215
219
matrices = [getattr (sample , matrix_name ) for sample in samples ]
216
- conflicts , concatenated_matrix = concatenate_matrices (n_processes , input_ids , matrices )
217
-
220
+ output_index = getattr (output , matrix_name ).index
221
+ align_to = output_index if matrix_name == "var" else None
222
+ conflicts , concatenated_matrix = concatenate_matrices (n_processes , input_ids , matrices , align_to )
223
+ if concatenated_matrix .empty :
224
+ concatenated_matrix .index = output_index
218
225
# Write the conflicts to the output
219
- matrix_index = getattr (output , matrix_name ).index
220
226
for conflict_name , conflict_data in conflicts .items ():
221
- getattr (output , f"{ matrix_name } m" )[conflict_name ] = conflict_data . reindex ( matrix_index )
227
+ getattr (output , f"{ matrix_name } m" )[conflict_name ] = conflict_data
222
228
223
229
# Set other annotation matrices in the output
224
- setattr (output , matrix_name , pd . DataFrame () if concatenated_matrix is None else concatenated_matrix )
230
+ setattr (output , matrix_name , concatenated_matrix )
225
231
226
232
return output
227
233
0 commit comments