Control generated text using logits processor
Source: NVIDIA/TensorRT-LLM.
### Control generated text using logits processor
from typing import List, Optional

import torch

from tensorrt_llm import LLM
from tensorrt_llm.sampling_params import (BatchedLogitsProcessor,
                                          LogitsProcessor, SamplingParams)


# The recommended way to create a customized logits processor:
# * Subclass LogitsProcessor and implement the processing logic in the __call__ method.
# * Create an instance and pass it to SamplingParams.
# Alternatively, you can create any callable with the same signature as the __call__ method
# (a function-based sketch follows the class below).
# This simple callback will output a specific token at each step irrespective of the prompt.
# Refer to ../bindings/executor/example_logits_processor.py for a more
# sophisticated callback that generates JSON structured output.
# Please also refer to sampling_params.py for adding a subclass to the approved class list for deserialization.
class MyLogitsProcessor(LogitsProcessor):

    def __init__(self, allowed_token_id: int):
        self.allowed_token_id = allowed_token_id

    def __call__(self, req_id: int, logits: torch.Tensor,
                 token_ids: List[List[int]], stream_ptr: int,
                 client_id: Optional[int]):
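        # Build a host-side mask with the same shape as logits: -inf everywhere
        # except at allowed_token_id, so only that token can be sampled.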
        mask = torch.full_like(logits, fill_value=float("-inf"), device="cpu")
        mask[:, :, self.allowed_token_id] = 0

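        # Copy the mask to the device and apply it on the CUDA stream identified
        # by stream_ptr; when stream_ptr is None, torch.cuda.stream(None) is a
        # no-op and the current stream is used.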
        stream = None if stream_ptr is None else torch.cuda.ExternalStream(
            stream_ptr)
        with torch.cuda.stream(stream):
            mask = mask.to(logits.device, non_blocking=True)
            logits += mask


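# As noted above, any callable with the same signature as LogitsProcessor.__call__
# can also serve as a per-request logits processor. A minimal sketch of that
# alternative (illustrative only; it is not used in main() below, and the
# hard-coded token id simply mirrors the example above):
def allow_only_token_42(req_id: int, logits: torch.Tensor,
                        token_ids: List[List[int]], stream_ptr: int,
                        client_id: Optional[int]):
    mask = torch.full_like(logits, fill_value=float("-inf"), device="cpu")
    mask[:, :, 42] = 0
    stream = None if stream_ptr is None else torch.cuda.ExternalStream(
        stream_ptr)
    with torch.cuda.stream(stream):
        logits += mask.to(logits.device, non_blocking=True)
# It would be passed the same way: SamplingParams(logits_processor=allow_only_token_42)

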
# The recommended way to create a customized batched logits processor:
# * Subclass BatchedLogitsProcessor and implement the processing logic in the __call__ method.
# * Create an instance and pass it to LLM.
# Alternatively, you can create any callable with the same signature as the __call__ method.
# A batched logits processor's arguments for all requests in a batch are made available as lists.
# This helps the user optimize the callback for large batch sizes. For example:
# 1. Process more work on host, e.g. running a JSON state machine, in parallel with the model forward pass on device.
# 2. Coalesce H2D memory transfers for all requests into a single cudaMemcpyAsync call
#    (a sketch of this follows the class below).
# 3. Launch a single batched kernel, e.g. for updating logits on device.
class MyBatchedLogitsProcessor(BatchedLogitsProcessor):

    def __init__(self, allowed_token_id: int):
        self.allowed_token_id = allowed_token_id

    def __call__(self, req_ids: List[int], logits: List[torch.Tensor],
                 token_ids: List[List[List[int]]], stream_ptr: int,
                 client_ids: List[Optional[int]]):
        # Generate masks for all requests on host
        masks = []
        for req_id, req_logits, req_token_ids, client_id in zip(
                req_ids, logits, token_ids, client_ids):
            mask = torch.full_like(req_logits,
                                   fill_value=float("-inf"),
                                   device="cpu")
            mask[:, :, self.allowed_token_id] = 0
            masks.append(mask)

        # Move masks to device and add to logits using non-blocking operations
        with torch.cuda.stream(torch.cuda.ExternalStream(stream_ptr)):
            for req_logits, mask in zip(logits, masks):
                req_logits += mask.to(req_logits.device, non_blocking=True)

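
# Sketch of optimization (2) above: coalesce the host-to-device transfer for the
# whole batch into a single copy. This assumes every request's logits tensor has
# the same shape, which is not guaranteed in general; illustrative only and not
# used in main() below.
class MyCoalescedBatchedLogitsProcessor(BatchedLogitsProcessor):

    def __init__(self, allowed_token_id: int):
        self.allowed_token_id = allowed_token_id

    def __call__(self, req_ids: List[int], logits: List[torch.Tensor],
                 token_ids: List[List[List[int]]], stream_ptr: int,
                 client_ids: List[Optional[int]]):
        # One stacked host-side mask covering every request in the batch
        masks = torch.full((len(logits), ) + tuple(logits[0].shape),
                           fill_value=float("-inf"),
                           dtype=logits[0].dtype,
                           device="cpu")
        masks[..., self.allowed_token_id] = 0

        # Single H2D copy for the whole batch, then per-request in-place updates
        with torch.cuda.stream(torch.cuda.ExternalStream(stream_ptr)):
            masks = masks.to(logits[0].device, non_blocking=True)
            for req_logits, mask in zip(logits, masks):
                req_logits += mask
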

def main():

    # Batched logits processor (only supported in TensorRT backend)
    # should be specified when initializing LLM.
    llm = LLM(
        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        batched_logits_processor=MyBatchedLogitsProcessor(allowed_token_id=42))

    # Sample prompts
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
    ]

    # Generate text
    for prompt_id, prompt in enumerate(prompts):
        # Use the non-batched logits processor callback only for prompts with an odd prompt_id
        if prompt_id % 2 == 0:
            sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
        else:
            # Each prompt can be given its own logits processor at runtime
            sampling_params = SamplingParams(
                temperature=0.8,
                top_p=0.95,
                logits_processor=MyLogitsProcessor(allowed_token_id=42))

        for output in llm.generate([prompt], sampling_params):
            print(
                f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}"
            )

    # Got output like
    # Prompt: 'Hello, my name is', Generated text: '\n\nJane Smith. I am a student pursuing my degree in Computer Science at [university]. I enjoy learning new things, especially technology and programming'
    # Prompt: 'The president of the United States is', Generated text: "''''''''''''''''''''''''''''''''"

    # Use batched processor with batch size = 2
    sampling_params = SamplingParams(apply_batched_logits_processor=True)
    for output in llm.generate(prompts, sampling_params):
        print(
            f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}"
        )

    # Got output like
    # Prompt: 'Hello, my name is', Generated text: "''''''''''''''''''''''''''''''''"
    # Prompt: 'The president of the United States is', Generated text: "''''''''''''''''''''''''''''''''"


if __name__ == '__main__':
    main()