11from __future__ import annotations
22
33import collections .abc as cabc
4+ import enum
45import os
56import stat
67import sys
2324 from .core import Parameter
2425 from .shell_completion import CompletionItem
2526
27+ ParamTypeValue = t .TypeVar ("ParamTypeValue" )
28+
2629
2730class ParamType :
2831 """Represents the type of a parameter. Validates and converts values
@@ -86,10 +89,10 @@ def __call__(
8689 if value is not None :
8790 return self .convert (value , param , ctx )
8891
89- def get_metavar (self , param : Parameter ) -> str | None :
92+ def get_metavar (self , param : Parameter , ctx : Context ) -> str | None :
9093 """Returns the metavar default for this param if it provides one."""
9194
92- def get_missing_message (self , param : Parameter ) -> str | None :
95+ def get_missing_message (self , param : Parameter , ctx : Context | None ) -> str | None :
9396 """Optionally might return extra information about a missing
9497 parameter.
9598
@@ -227,29 +230,35 @@ def __repr__(self) -> str:
227230 return "STRING"
228231
229232
230- class Choice (ParamType ):
233+ class Choice (ParamType , t . Generic [ ParamTypeValue ] ):
231234 """The choice type allows a value to be checked against a fixed set
232- of supported values. All of these values have to be strings.
233-
234- You should only pass a list or tuple of choices. Other iterables
235- (like generators) may lead to surprising results.
235+ of supported values.
236236
237- The resulting value will always be one of the originally passed choices
238- regardless of ``case_sensitive`` or any ``ctx.token_normalize_func``
239- being specified.
237+ You may pass any iterable value which will be converted to a tuple
238+ and thus will only be iterated once.
240239
241- See :ref:`choice-opts` for an example.
240+ The resulting value will always be one of the originally passed choices.
241+ See :meth:`normalize_choice` for more info on the mapping of strings
242+ to choices. See :ref:`choice-opts` for an example.
242243
243244 :param case_sensitive: Set to false to make choices case
244245 insensitive. Defaults to true.
246+
247+ .. versionchanged:: 8.2.0
248+ Non-``str`` ``choices`` are now supported. It can additionally be any
249+ iterable. Before you were not recommended to pass anything but a list or
250+ tuple.
251+
252+ .. versionadded:: 8.2.0
253+ Choice normalization can be overridden via :meth:`normalize_choice`.
245254 """
246255
247256 name = "choice"
248257
249258 def __init__ (
250- self , choices : cabc .Sequence [ str ], case_sensitive : bool = True
259+ self , choices : cabc .Iterable [ ParamTypeValue ], case_sensitive : bool = True
251260 ) -> None :
252- self .choices = choices
261+ self .choices : cabc . Sequence [ ParamTypeValue ] = tuple ( choices )
253262 self .case_sensitive = case_sensitive
254263
255264 def to_info_dict (self ) -> dict [str , t .Any ]:
@@ -258,14 +267,54 @@ def to_info_dict(self) -> dict[str, t.Any]:
258267 info_dict ["case_sensitive" ] = self .case_sensitive
259268 return info_dict
260269
261- def get_metavar (self , param : Parameter ) -> str :
270+ def _normalized_mapping (
271+ self , ctx : Context | None = None
272+ ) -> cabc .Mapping [ParamTypeValue , str ]:
273+ """
274+ Returns mapping where keys are the original choices and the values are
275+ the normalized values that are accepted via the command line.
276+
277+ This is a simple wrapper around :meth:`normalize_choice`, use that
278+ instead which is supported.
279+ """
280+ return {
281+ choice : self .normalize_choice (
282+ choice = choice ,
283+ ctx = ctx ,
284+ )
285+ for choice in self .choices
286+ }
287+
288+ def normalize_choice (self , choice : ParamTypeValue , ctx : Context | None ) -> str :
289+ """
290+ Normalize a choice value, used to map a passed string to a choice.
291+ Each choice must have a unique normalized value.
292+
293+ By default uses :meth:`Context.token_normalize_func` and if not case
294+ sensitive, convert it to a casefolded value.
295+
296+ .. versionadded:: 8.2.0
297+ """
298+ normed_value = choice .name if isinstance (choice , enum .Enum ) else str (choice )
299+
300+ if ctx is not None and ctx .token_normalize_func is not None :
301+ normed_value = ctx .token_normalize_func (normed_value )
302+
303+ if not self .case_sensitive :
304+ normed_value = normed_value .casefold ()
305+
306+ return normed_value
307+
308+ def get_metavar (self , param : Parameter , ctx : Context ) -> str | None :
262309 if param .param_type_name == "option" and not param .show_choices : # type: ignore
263310 choice_metavars = [
264311 convert_type (type (choice )).name .upper () for choice in self .choices
265312 ]
266313 choices_str = "|" .join ([* dict .fromkeys (choice_metavars )])
267314 else :
268- choices_str = "|" .join ([str (i ) for i in self .choices ])
315+ choices_str = "|" .join (
316+ [str (i ) for i in self ._normalized_mapping (ctx = ctx ).values ()]
317+ )
269318
270319 # Use curly braces to indicate a required argument.
271320 if param .required and param .param_type_name == "argument" :
@@ -274,46 +323,48 @@ def get_metavar(self, param: Parameter) -> str:
274323 # Use square braces to indicate an option or optional argument.
275324 return f"[{ choices_str } ]"
276325
277- def get_missing_message (self , param : Parameter ) -> str :
278- return _ ("Choose from:\n \t {choices}" ).format (choices = ",\n \t " .join (self .choices ))
326+ def get_missing_message (self , param : Parameter , ctx : Context | None ) -> str :
327+ """
328+ Message shown when no choice is passed.
329+
330+ .. versionchanged:: 8.2.0 Added ``ctx`` argument.
331+ """
332+ return _ ("Choose from:\n \t {choices}" ).format (
333+ choices = ",\n \t " .join (self ._normalized_mapping (ctx = ctx ).values ())
334+ )
279335
280336 def convert (
281337 self , value : t .Any , param : Parameter | None , ctx : Context | None
282- ) -> t .Any :
283- # Match through normalization and case sensitivity
284- # first do token_normalize_func, then lowercase
285- # preserve original `value` to produce an accurate message in
286- # `self.fail`
287- normed_value = value
288- normed_choices = {choice : choice for choice in self .choices }
289-
290- if ctx is not None and ctx .token_normalize_func is not None :
291- normed_value = ctx .token_normalize_func (value )
292- normed_choices = {
293- ctx .token_normalize_func (normed_choice ): original
294- for normed_choice , original in normed_choices .items ()
295- }
296-
297- if not self .case_sensitive :
298- normed_value = normed_value .casefold ()
299- normed_choices = {
300- normed_choice .casefold (): original
301- for normed_choice , original in normed_choices .items ()
302- }
303-
304- if normed_value in normed_choices :
305- return normed_choices [normed_value ]
338+ ) -> ParamTypeValue :
339+ """
340+ For a given value from the parser, normalize it and find its
341+ matching normalized value in the list of choices. Then return the
342+ matched "original" choice.
343+ """
344+ normed_value = self .normalize_choice (choice = value , ctx = ctx )
345+ normalized_mapping = self ._normalized_mapping (ctx = ctx )
306346
307- self .fail (self .get_invalid_choice_message (value ), param , ctx )
347+ try :
348+ return next (
349+ original
350+ for original , normalized in normalized_mapping .items ()
351+ if normalized == normed_value
352+ )
353+ except StopIteration :
354+ self .fail (
355+ self .get_invalid_choice_message (value = value , ctx = ctx ),
356+ param = param ,
357+ ctx = ctx ,
358+ )
308359
309- def get_invalid_choice_message (self , value : t .Any ) -> str :
360+ def get_invalid_choice_message (self , value : t .Any , ctx : Context | None ) -> str :
310361 """Get the error message when the given choice is invalid.
311362
312363 :param value: The invalid value.
313364
314365 .. versionadded:: 8.2
315366 """
316- choices_str = ", " .join (map (repr , self .choices ))
367+ choices_str = ", " .join (map (repr , self ._normalized_mapping ( ctx = ctx ). values () ))
317368 return ngettext (
318369 "{value!r} is not {choice}." ,
319370 "{value!r} is not one of {choices}." ,
@@ -382,7 +433,7 @@ def to_info_dict(self) -> dict[str, t.Any]:
382433 info_dict ["formats" ] = self .formats
383434 return info_dict
384435
385- def get_metavar (self , param : Parameter ) -> str :
436+ def get_metavar (self , param : Parameter , ctx : Context ) -> str | None :
386437 return f"[{ '|' .join (self .formats )} ]"
387438
388439 def _try_to_convert_date (self , value : t .Any , format : str ) -> datetime | None :
0 commit comments