4
4
from functools import lru_cache
5
5
from operator import itemgetter
6
6
from os import path
7
- from re import error
8
- from typing import Any , Union
7
+ from typing import Any , Union , Tuple , Optional
9
8
10
9
import dask
11
10
import dask .dataframe as dd
26
25
def clean_country (
27
26
df : Union [pd .DataFrame , dd .DataFrame ],
28
27
column : str ,
29
- input_format : str = "auto" ,
28
+ input_format : Union [ str , Tuple [ str , ...]] = "auto" ,
30
29
output_format : str = "name" ,
31
30
fuzzy_dist : int = 0 ,
32
31
strict : bool = False ,
@@ -55,6 +54,10 @@ def clean_country(
55
54
- 'alpha-3': alpha-3 code ('USA')
56
55
- 'numeric': numeric code (840)
57
56
57
+ Can also be a tuple containing any combination of input formats,
58
+ for example to clean a column containing alpha-2 and numeric
59
+ codes set input_format to ('alpha-2', 'numeric').
60
+
58
61
(default: 'auto')
59
62
output_format
60
63
The desired ISO 3166 format of the country:
@@ -112,14 +115,7 @@ def clean_country(
112
115
1 US United States
113
116
"""
114
117
# pylint: disable=too-many-arguments
115
-
116
- input_formats = {"auto" , "name" , "official" , "alpha-2" , "alpha-3" , "numeric" }
117
118
output_formats = {"name" , "official" , "alpha-2" , "alpha-3" , "numeric" }
118
- if input_format not in input_formats :
119
- raise ValueError (
120
- f'input_format { input_format } is invalid, it needs to be one of "auto", '
121
- '"name", "official", "alpha-2", "alpha-3" or "numeric'
122
- )
123
119
if output_format not in output_formats :
124
120
raise ValueError (
125
121
f'output_format { output_format } is invalid, it needs to be "name", '
@@ -130,6 +126,7 @@ def clean_country(
130
126
"can't do fuzzy matching while strict mode is enabled, "
131
127
"set strict=False for fuzzy matching or fuzzy_dist=0 for strict matching"
132
128
)
129
+ input_formats = _input_format_to_tuple (input_format )
133
130
134
131
# convert to dask
135
132
df = to_dask (df )
@@ -140,7 +137,8 @@ def clean_country(
140
137
# amount of different codes to produce the report
141
138
df ["clean_code_tup" ] = df [column ].map_partitions (
142
139
lambda srs : [
143
- _format_country (x , input_format , output_format , fuzzy_dist , strict , errors ) for x in srs
140
+ _format_country (x , input_formats , output_format , fuzzy_dist , strict , errors )
141
+ for x in srs
144
142
],
145
143
meta = object ,
146
144
)
@@ -168,7 +166,9 @@ def clean_country(
168
166
169
167
170
168
def validate_country (
171
- x : Union [str , int , pd .Series ], input_format : str = "auto" , strict : bool = True
169
+ x : Union [str , int , pd .Series ],
170
+ input_format : Union [str , Tuple [str , ...]] = "auto" ,
171
+ strict : bool = True ,
172
172
) -> Union [bool , pd .Series ]:
173
173
"""
174
174
Validate country names.
@@ -188,6 +188,10 @@ def validate_country(
188
188
- 'alpha-3': alpha-3 code ('USA')
189
189
- 'numeric': numeric code (840)
190
190
191
+ Can also be a tuple containing any combination of input formats,
192
+ for example to clean a column containing alpha-2 and numeric
193
+ codes set input_format to ('alpha-2', 'numeric').
194
+
191
195
(default: 'auto')
192
196
strict
193
197
If True, matching for input formats 'name' and 'official' are done by
@@ -207,18 +211,18 @@ def validate_country(
207
211
1 False
208
212
Name: country, dtype: bool
209
213
"""
210
-
214
+ input_formats = _input_format_to_tuple ( input_format )
211
215
if isinstance (x , pd .Series ):
212
216
x = x .astype (str ).str .lower ().str .strip ()
213
- return x .apply (_check_country , args = (input_format , strict , False ))
217
+ return x .apply (_check_country , args = (input_formats , strict , False ))
214
218
215
219
x = str (x ).lower ().strip ()
216
- return _check_country (x , input_format , strict , False )
220
+ return _check_country (x , input_formats , strict , False )
217
221
218
222
219
223
def _format_country (
220
224
val : Any ,
221
- input_format : str ,
225
+ input_formats : Tuple [ str , ...] ,
222
226
output_format : str ,
223
227
fuzzy_dist : int ,
224
228
strict : bool ,
@@ -241,9 +245,13 @@ def _format_country(
241
245
# could not be parsed) or "success" (a successful parse of the value).
242
246
243
247
country = str (val ).lower ().strip ()
244
- result_index , status = _check_country (country , input_format , strict , True )
248
+ result_index , status = _check_country (country , input_formats , strict , True )
245
249
246
- if fuzzy_dist > 0 and status == "unknown" and input_format in ("auto" , "name" , "official" ):
250
+ if (
251
+ fuzzy_dist > 0
252
+ and status == "unknown"
253
+ and ("name" in input_formats or "official" in input_formats )
254
+ ):
247
255
result_index , status = _check_fuzzy_dist (country , fuzzy_dist )
248
256
249
257
if status == "null" :
@@ -264,16 +272,16 @@ def _format_country(
264
272
265
273
266
274
@lru_cache (maxsize = 2 ** 20 )
267
- def _check_country (country : str , input_format : str , strict : bool , clean : bool ) -> Any :
275
+ def _check_country (country : str , input_formats : Tuple [ str , ...] , strict : bool , clean : bool ) -> Any :
268
276
"""
269
277
Finds the index of the given country in the DATA dataframe.
270
278
271
279
Parameters
272
280
----------
273
281
country
274
282
string containing the country value being cleaned
275
- input_format
276
- the ISO 3166 input format of the country
283
+ input_formats
284
+ Tuple containing potential ISO 3166 input formats of the country
277
285
strict
278
286
If True, for input types "name" and "offical" the function looks for a direct match
279
287
in the DATA dataframe. If False, the country input is searched for a regex match.
@@ -284,19 +292,18 @@ def _check_country(country: str, input_format: str, strict: bool, clean: bool) -
284
292
if country in NULL_VALUES :
285
293
return (None , "null" ) if clean else False
286
294
287
- if input_format == "auto" :
288
- input_format = _get_format_from_name (country )
295
+ country_format = _get_format_from_name (country )
296
+ input_format = _get_format_if_allowed (country_format , input_formats )
297
+ if not input_format :
298
+ return (None , "unknown" ) if clean else False
289
299
290
300
if strict and input_format == "regex" :
291
301
for form in ("name" , "official" ):
292
- try :
293
- ind = DATA [
294
- DATA [form ].str .contains (f"^{ country } $" , flags = re .IGNORECASE , na = False )
295
- ].index
296
- if np .size (ind ) > 0 :
297
- return (ind [0 ], "success" ) if clean else True
298
- except error :
299
- return (None , "unknown" ) if clean else False
302
+ ind = DATA [
303
+ DATA [form ].str .contains (f"^{ re .escape (country )} $" , flags = re .IGNORECASE , na = False )
304
+ ].index
305
+ if np .size (ind ) > 0 :
306
+ return (ind [0 ], "success" ) if clean else True
300
307
301
308
elif not strict and input_format in ("regex" , "name" , "official" ):
302
309
for index , country_regex in enumerate (REGEXES ):
@@ -305,7 +312,9 @@ def _check_country(country: str, input_format: str, strict: bool, clean: bool) -
305
312
306
313
else :
307
314
ind = DATA [
308
- DATA [input_format ].str .contains (f"^{ country } $" , flags = re .IGNORECASE , na = False )
315
+ DATA [input_format ].str .contains (
316
+ f"^{ re .escape (country )} $" , flags = re .IGNORECASE , na = False
317
+ )
309
318
].index
310
319
if np .size (ind ) > 0 :
311
320
return (ind [0 ], "success" ) if clean else True
@@ -346,3 +355,45 @@ def _get_format_from_name(name: str) -> str:
346
355
return "numeric"
347
356
except ValueError :
348
357
return "alpha-2" if len (name ) == 2 else "alpha-3" if len (name ) == 3 else "regex"
358
+
359
+
360
+ def _get_format_if_allowed (input_format : str , allowed_formats : Tuple [str , ...]) -> Optional [str ]:
361
+ """
362
+ Returns the input format if it's an allowed format.
363
+ "regex" input_format is only returned if "name" and "official are
364
+ allowed. This is because when strict = True and input_format = "regex"
365
+ both the "name" and "official" columns in the DATA dataframe are checked.
366
+ """
367
+ if input_format == "regex" :
368
+ if "name" in allowed_formats and "official" in allowed_formats :
369
+ return "regex"
370
+
371
+ return (
372
+ "name"
373
+ if "name" in allowed_formats
374
+ else "official"
375
+ if "official" in allowed_formats
376
+ else None
377
+ )
378
+
379
+ return input_format if input_format in allowed_formats else None
380
+
381
+
382
+ def _input_format_to_tuple (input_format : Union [str , Tuple [str , ...]]) -> Tuple [str , ...]:
383
+ """
384
+ Converts a string input format to a tuple of allowed input formats and raises an error
385
+ if an input format is not valid.
386
+ """
387
+ input_formats = {"auto" , "name" , "official" , "alpha-2" , "alpha-3" , "numeric" }
388
+ if isinstance (input_format , str ):
389
+ if input_format == "auto" :
390
+ return ("name" , "official" , "alpha-2" , "alpha-3" , "numeric" )
391
+ input_format = (input_format ,)
392
+
393
+ for fmt in input_format :
394
+ if fmt not in input_formats :
395
+ raise ValueError (
396
+ f'input_format { fmt } is invalid, it needs to be one of "auto", '
397
+ '"name", "official", "alpha-2", "alpha-3" or "numeric'
398
+ )
399
+ return input_format
0 commit comments