Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CLI subcommand union and alias support #380

Merged
Prev Previous commit
Next Next commit
Revert and bring this fix in separately.
  • Loading branch information
kschwab committed Sep 2, 2024
commit cc42d7edec64a3ec9bc42ec9673e1efff9e20091
37 changes: 15 additions & 22 deletions pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -1248,33 +1248,26 @@ def _load_env_vars(

return self

def _get_merge_parsed_list_types(
self, parsed_list: list[str], field_name: str
) -> tuple[Optional[type], Optional[type]]:
merge_type = self._cli_dict_args.get(field_name, list)
if (
merge_type is list
or not origin_is_union(get_origin(merge_type))
or not any(
type_
for type_ in get_args(merge_type)
if type_ is not type(None) and get_origin(type_) not in (dict, Mapping)
)
):
inferred_type = merge_type
else:
inferred_type = list if parsed_list and (len(parsed_list) > 1 or parsed_list[0].startswith('[')) else str

return merge_type, inferred_type

def _merge_parsed_list(self, parsed_list: list[str], field_name: str) -> str:
try:
merged_list: list[str] = []
is_last_consumed_a_value = False
merge_type, inferred_type = self._get_merge_parsed_list_types(parsed_list, field_name)
merge_type = self._cli_dict_args.get(field_name, list)
if (
merge_type is list
or not origin_is_union(get_origin(merge_type))
or not any(
type_
for type_ in get_args(merge_type)
if type_ is not type(None) and get_origin(type_) not in (dict, Mapping)
)
):
inferred_type = merge_type
else:
inferred_type = (
list if parsed_list and (len(parsed_list) > 1 or parsed_list[0].startswith('[')) else str
)
for val in parsed_list:
if not isinstance(val, str):
break
val = val.strip()
if val.startswith('[') and val.endswith(']'):
val = val[1:-1].strip()
Expand Down