Skip to content

Commit 4ece17a

Browse files
authored
Expand Choice token normalization + make generic (#2796)
1 parent 8a47580 commit 4ece17a

File tree

10 files changed

+239
-71
lines changed

10 files changed

+239
-71
lines changed

CHANGES.rst

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,16 @@ Unreleased
8888
- Parameters cannot be required nor prompted or an error is raised.
8989
- A warning will be printed when something deprecated is used.
9090

91-
- Add a ``catch_exceptions`` parameter to :class:`CliRunner`. If
92-
``catch_exceptions`` is not passed to :meth:`CliRunner.invoke`,
91+
- Add a ``catch_exceptions`` parameter to :class:``CliRunner``. If
92+
``catch_exceptions`` is not passed to :meth:``CliRunner.invoke``,
9393
the value from :class:`CliRunner`. :issue:`2817` :pr:`2818`
9494
- ``Option.flag_value`` will no longer have a default value set based on
9595
``Option.default`` if ``Option.is_flag`` is ``False``. This results in
9696
``Option.default`` not needing to implement `__bool__`. :pr:`2829`
9797
- Incorrect ``click.edit`` typing has been corrected. :pr:`2804`
98+
- :class:``Choice`` is now generic and supports any iterable value.
99+
This allows you to use enums and other non-``str`` values. :pr:`2796`
100+
:issue:`605`
98101

99102
Version 8.1.8
100103
-------------

docs/options.rst

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -375,15 +375,22 @@ In that case you can use :class:`Choice` type. It can be instantiated
375375
with a list of valid values. The originally passed choice will be returned,
376376
not the str passed on the command line. Token normalization functions and
377377
``case_sensitive=False`` can cause the two to be different but still match.
378+
:meth:`Choice.normalize_choice` for more info.
378379

379380
Example:
380381

381382
.. click:example::
382383
384+
import enum
385+
386+
class HashType(enum.Enum):
387+
MD5 = 'MD5'
388+
SHA1 = 'SHA1'
389+
383390
@click.command()
384391
@click.option('--hash-type',
385-
type=click.Choice(['MD5', 'SHA1'], case_sensitive=False))
386-
def digest(hash_type):
392+
type=click.Choice(HashType, case_sensitive=False))
393+
def digest(hash_type: HashType):
387394
click.echo(hash_type)
388395

389396
What it looks like:
@@ -398,15 +405,16 @@ What it looks like:
398405
println()
399406
invoke(digest, args=['--help'])
400407

401-
Only pass the choices as list or tuple. Other iterables (like
402-
generators) may lead to unexpected results.
408+
Since version 8.2.0 any iterable may be passed to :class:`Choice`, here
409+
an ``Enum`` is used which will result in all enum values to be valid
410+
choices.
403411

404412
Choices work with options that have ``multiple=True``. If a ``default``
405413
value is given with ``multiple=True``, it should be a list or tuple of
406414
valid choices.
407415

408-
Choices should be unique after considering the effects of
409-
``case_sensitive`` and any specified token normalization function.
416+
Choices should be unique after normalization, see
417+
:meth:`Choice.normalize_choice` for more info.
410418

411419
.. versionchanged:: 7.1
412420
The resulting value from an option will always be one of the

src/click/core.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2192,11 +2192,11 @@ def human_readable_name(self) -> str:
21922192
"""
21932193
return self.name # type: ignore
21942194

2195-
def make_metavar(self) -> str:
2195+
def make_metavar(self, ctx: Context) -> str:
21962196
if self.metavar is not None:
21972197
return self.metavar
21982198

2199-
metavar = self.type.get_metavar(self)
2199+
metavar = self.type.get_metavar(param=self, ctx=ctx)
22002200

22012201
if metavar is None:
22022202
metavar = self.type.name.upper()
@@ -2775,7 +2775,7 @@ def _write_opts(opts: cabc.Sequence[str]) -> str:
27752775
any_prefix_is_slash = True
27762776

27772777
if not self.is_flag and not self.count:
2778-
rv += f" {self.make_metavar()}"
2778+
rv += f" {self.make_metavar(ctx=ctx)}"
27792779

27802780
return rv
27812781

@@ -3056,10 +3056,10 @@ def human_readable_name(self) -> str:
30563056
return self.metavar
30573057
return self.name.upper() # type: ignore
30583058

3059-
def make_metavar(self) -> str:
3059+
def make_metavar(self, ctx: Context) -> str:
30603060
if self.metavar is not None:
30613061
return self.metavar
3062-
var = self.type.get_metavar(self)
3062+
var = self.type.get_metavar(param=self, ctx=ctx)
30633063
if not var:
30643064
var = self.name.upper() # type: ignore
30653065
if self.deprecated:
@@ -3088,10 +3088,10 @@ def _parse_decls(
30883088
return name, [arg], []
30893089

30903090
def get_usage_pieces(self, ctx: Context) -> list[str]:
3091-
return [self.make_metavar()]
3091+
return [self.make_metavar(ctx)]
30923092

30933093
def get_error_hint(self, ctx: Context) -> str:
3094-
return f"'{self.make_metavar()}'"
3094+
return f"'{self.make_metavar(ctx)}'"
30953095

30963096
def add_to_parser(self, parser: _OptionParser, ctx: Context) -> None:
30973097
parser.add_argument(dest=self.name, nargs=self.nargs, obj=self)

src/click/exceptions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,9 @@ def format_message(self) -> str:
174174

175175
msg = self.message
176176
if self.param is not None:
177-
msg_extra = self.param.type.get_missing_message(self.param)
177+
msg_extra = self.param.type.get_missing_message(
178+
param=self.param, ctx=self.ctx
179+
)
178180
if msg_extra:
179181
if msg:
180182
msg += f". {msg_extra}"

src/click/types.py

Lines changed: 96 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import collections.abc as cabc
4+
import enum
45
import os
56
import stat
67
import sys
@@ -23,6 +24,8 @@
2324
from .core import Parameter
2425
from .shell_completion import CompletionItem
2526

27+
ParamTypeValue = t.TypeVar("ParamTypeValue")
28+
2629

2730
class ParamType:
2831
"""Represents the type of a parameter. Validates and converts values
@@ -86,10 +89,10 @@ def __call__(
8689
if value is not None:
8790
return self.convert(value, param, ctx)
8891

89-
def get_metavar(self, param: Parameter) -> str | None:
92+
def get_metavar(self, param: Parameter, ctx: Context) -> str | None:
9093
"""Returns the metavar default for this param if it provides one."""
9194

92-
def get_missing_message(self, param: Parameter) -> str | None:
95+
def get_missing_message(self, param: Parameter, ctx: Context | None) -> str | None:
9396
"""Optionally might return extra information about a missing
9497
parameter.
9598
@@ -227,29 +230,35 @@ def __repr__(self) -> str:
227230
return "STRING"
228231

229232

230-
class Choice(ParamType):
233+
class Choice(ParamType, t.Generic[ParamTypeValue]):
231234
"""The choice type allows a value to be checked against a fixed set
232-
of supported values. All of these values have to be strings.
233-
234-
You should only pass a list or tuple of choices. Other iterables
235-
(like generators) may lead to surprising results.
235+
of supported values.
236236
237-
The resulting value will always be one of the originally passed choices
238-
regardless of ``case_sensitive`` or any ``ctx.token_normalize_func``
239-
being specified.
237+
You may pass any iterable value which will be converted to a tuple
238+
and thus will only be iterated once.
240239
241-
See :ref:`choice-opts` for an example.
240+
The resulting value will always be one of the originally passed choices.
241+
See :meth:`normalize_choice` for more info on the mapping of strings
242+
to choices. See :ref:`choice-opts` for an example.
242243
243244
:param case_sensitive: Set to false to make choices case
244245
insensitive. Defaults to true.
246+
247+
.. versionchanged:: 8.2.0
248+
Non-``str`` ``choices`` are now supported. It can additionally be any
249+
iterable. Before you were not recommended to pass anything but a list or
250+
tuple.
251+
252+
.. versionadded:: 8.2.0
253+
Choice normalization can be overridden via :meth:`normalize_choice`.
245254
"""
246255

247256
name = "choice"
248257

249258
def __init__(
250-
self, choices: cabc.Sequence[str], case_sensitive: bool = True
259+
self, choices: cabc.Iterable[ParamTypeValue], case_sensitive: bool = True
251260
) -> None:
252-
self.choices = choices
261+
self.choices: cabc.Sequence[ParamTypeValue] = tuple(choices)
253262
self.case_sensitive = case_sensitive
254263

255264
def to_info_dict(self) -> dict[str, t.Any]:
@@ -258,14 +267,54 @@ def to_info_dict(self) -> dict[str, t.Any]:
258267
info_dict["case_sensitive"] = self.case_sensitive
259268
return info_dict
260269

261-
def get_metavar(self, param: Parameter) -> str:
270+
def _normalized_mapping(
271+
self, ctx: Context | None = None
272+
) -> cabc.Mapping[ParamTypeValue, str]:
273+
"""
274+
Returns mapping where keys are the original choices and the values are
275+
the normalized values that are accepted via the command line.
276+
277+
This is a simple wrapper around :meth:`normalize_choice`, use that
278+
instead which is supported.
279+
"""
280+
return {
281+
choice: self.normalize_choice(
282+
choice=choice,
283+
ctx=ctx,
284+
)
285+
for choice in self.choices
286+
}
287+
288+
def normalize_choice(self, choice: ParamTypeValue, ctx: Context | None) -> str:
289+
"""
290+
Normalize a choice value, used to map a passed string to a choice.
291+
Each choice must have a unique normalized value.
292+
293+
By default uses :meth:`Context.token_normalize_func` and if not case
294+
sensitive, convert it to a casefolded value.
295+
296+
.. versionadded:: 8.2.0
297+
"""
298+
normed_value = choice.name if isinstance(choice, enum.Enum) else str(choice)
299+
300+
if ctx is not None and ctx.token_normalize_func is not None:
301+
normed_value = ctx.token_normalize_func(normed_value)
302+
303+
if not self.case_sensitive:
304+
normed_value = normed_value.casefold()
305+
306+
return normed_value
307+
308+
def get_metavar(self, param: Parameter, ctx: Context) -> str | None:
262309
if param.param_type_name == "option" and not param.show_choices: # type: ignore
263310
choice_metavars = [
264311
convert_type(type(choice)).name.upper() for choice in self.choices
265312
]
266313
choices_str = "|".join([*dict.fromkeys(choice_metavars)])
267314
else:
268-
choices_str = "|".join([str(i) for i in self.choices])
315+
choices_str = "|".join(
316+
[str(i) for i in self._normalized_mapping(ctx=ctx).values()]
317+
)
269318

270319
# Use curly braces to indicate a required argument.
271320
if param.required and param.param_type_name == "argument":
@@ -274,46 +323,48 @@ def get_metavar(self, param: Parameter) -> str:
274323
# Use square braces to indicate an option or optional argument.
275324
return f"[{choices_str}]"
276325

277-
def get_missing_message(self, param: Parameter) -> str:
278-
return _("Choose from:\n\t{choices}").format(choices=",\n\t".join(self.choices))
326+
def get_missing_message(self, param: Parameter, ctx: Context | None) -> str:
327+
"""
328+
Message shown when no choice is passed.
329+
330+
.. versionchanged:: 8.2.0 Added ``ctx`` argument.
331+
"""
332+
return _("Choose from:\n\t{choices}").format(
333+
choices=",\n\t".join(self._normalized_mapping(ctx=ctx).values())
334+
)
279335

280336
def convert(
281337
self, value: t.Any, param: Parameter | None, ctx: Context | None
282-
) -> t.Any:
283-
# Match through normalization and case sensitivity
284-
# first do token_normalize_func, then lowercase
285-
# preserve original `value` to produce an accurate message in
286-
# `self.fail`
287-
normed_value = value
288-
normed_choices = {choice: choice for choice in self.choices}
289-
290-
if ctx is not None and ctx.token_normalize_func is not None:
291-
normed_value = ctx.token_normalize_func(value)
292-
normed_choices = {
293-
ctx.token_normalize_func(normed_choice): original
294-
for normed_choice, original in normed_choices.items()
295-
}
296-
297-
if not self.case_sensitive:
298-
normed_value = normed_value.casefold()
299-
normed_choices = {
300-
normed_choice.casefold(): original
301-
for normed_choice, original in normed_choices.items()
302-
}
303-
304-
if normed_value in normed_choices:
305-
return normed_choices[normed_value]
338+
) -> ParamTypeValue:
339+
"""
340+
For a given value from the parser, normalize it and find its
341+
matching normalized value in the list of choices. Then return the
342+
matched "original" choice.
343+
"""
344+
normed_value = self.normalize_choice(choice=value, ctx=ctx)
345+
normalized_mapping = self._normalized_mapping(ctx=ctx)
306346

307-
self.fail(self.get_invalid_choice_message(value), param, ctx)
347+
try:
348+
return next(
349+
original
350+
for original, normalized in normalized_mapping.items()
351+
if normalized == normed_value
352+
)
353+
except StopIteration:
354+
self.fail(
355+
self.get_invalid_choice_message(value=value, ctx=ctx),
356+
param=param,
357+
ctx=ctx,
358+
)
308359

309-
def get_invalid_choice_message(self, value: t.Any) -> str:
360+
def get_invalid_choice_message(self, value: t.Any, ctx: Context | None) -> str:
310361
"""Get the error message when the given choice is invalid.
311362
312363
:param value: The invalid value.
313364
314365
.. versionadded:: 8.2
315366
"""
316-
choices_str = ", ".join(map(repr, self.choices))
367+
choices_str = ", ".join(map(repr, self._normalized_mapping(ctx=ctx).values()))
317368
return ngettext(
318369
"{value!r} is not {choice}.",
319370
"{value!r} is not one of {choices}.",
@@ -382,7 +433,7 @@ def to_info_dict(self) -> dict[str, t.Any]:
382433
info_dict["formats"] = self.formats
383434
return info_dict
384435

385-
def get_metavar(self, param: Parameter) -> str:
436+
def get_metavar(self, param: Parameter, ctx: Context) -> str | None:
386437
return f"[{'|'.join(self.formats)}]"
387438

388439
def _try_to_convert_date(self, value: t.Any, format: str) -> datetime | None:

0 commit comments

Comments
 (0)