Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion gemma/gm/text/_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,28 @@ def sample(

# TODO(epot): Donate the `init_state`, `last_state`

sampler = _sampler_loop.SamplerLoop(
# Static attributes. Changing those will trigger a recompilation.
model=self.model,
end_tokens=(
self.tokenizer.special_tokens.EOS,
self.tokenizer.special_tokens.END_OF_TURN,
# BEGIN_OF_TOOL_RESPONSE was introduced in Gemma3; Gemma2 tokenizer
# does not define it. Only include it when available.
*(
(self.tokenizer.special_tokens.BEGIN_OF_TOOL_RESPONSE,)
if hasattr(
self.tokenizer.special_tokens, 'BEGIN_OF_TOOL_RESPONSE'
)
else ()
),
*self._normalized_stop_tokens,
),
forbidden_tokens=self._normalized_forbidden_tokens,
sampling=sampling,
cache_length=self.cache_length,
special_tokens=self.tokenizer.special_tokens,
)
sampler = self._initialize_sampler_loop(sampling)

# TODO(epot): Use `jnp.cond` to detect when the cache is full (or use
Expand Down Expand Up @@ -577,7 +599,7 @@ def _normalize_token(tokenizer, token: str | int) -> int:
token_id = tokenizer.encode(token)
if len(token_id) != 1:
raise ValueError(
'Invalid token: {token!r}. `stop_token`s and `forbidden_token`s must'
f'Invalid token: {token!r}. `stop_token`s and `forbidden_token`s must'
' map to single token ids in the vocab.'
)
(token_id,) = token_id
Expand Down
39 changes: 39 additions & 0 deletions gemma/gm/text/_sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@
# limitations under the License.

from gemma import gm
from gemma.gm.text import _sampler
from gemma.gm.text import _sampler_loop
from gemma.gm.text import _tokenizer
import jax
import jax.numpy as jnp
import numpy as np
import pytest


def test_end_tokens_mask():
Expand Down Expand Up @@ -54,3 +57,39 @@ def test_sampler():
pad_length=None,
)
sampler.sample('Hello world')


def test_normalize_token_error_message_contains_token_value():
"""_normalize_token should interpolate the token value in the error message.

Regression test for a missing f-string prefix that caused the error message
to show the literal string '{token!r}' instead of the actual token value.
"""
tokenizer = gm.testing.DummyTokenizer()
# 'Hello world' encodes to two tokens (the dummy tokenizer splits on spaces),
# so _normalize_token should raise ValueError for it.
with pytest.raises(ValueError, match=r'Hello world'):
_sampler._normalize_token(tokenizer, 'Hello world')


def test_sampler_gemma2_tokenizer_no_begin_of_tool_response():
"""Sampler with Gemma2 tokenizer must not crash on BEGIN_OF_TOOL_RESPONSE.

Gemma2's _Gemma2SpecialTokens does not define BEGIN_OF_TOOL_RESPONSE (that
attribute was introduced in Gemma3). The Sampler.sample() method previously
accessed it unconditionally, raising AttributeError for any Gemma2 model.
This test verifies the hasattr() guard prevents the crash.
"""
# DummyTokenizer uses _Gemma3SpecialTokens which has the attribute.
# We verify the guard logic directly: _Gemma2SpecialTokens must NOT have it.
assert not hasattr(
_tokenizer._Gemma2SpecialTokens, 'BEGIN_OF_TOOL_RESPONSE'
), (
'_Gemma2SpecialTokens should not define BEGIN_OF_TOOL_RESPONSE'
)
# And _Gemma3SpecialTokens MUST have it.
assert hasattr(
_tokenizer._Gemma3SpecialTokens, 'BEGIN_OF_TOOL_RESPONSE'
), (
'_Gemma3SpecialTokens should define BEGIN_OF_TOOL_RESPONSE'
)