Control generated text using logits processor

Source: NVIDIA/TensorRT-LLM.

### Control generated text using logits processor
from typing import List, Optional

import torch

from tensorrt_llm import LLM
from tensorrt_llm.sampling_params import (BatchedLogitsProcessor,
                                          LogitsProcessor, SamplingParams)


# The recommended way to create a customized logits processor:
#     * Subclass LogitsProcessor and implement the processing logic in the __call__ method.
#     * Create an instance and pass it to SamplingParams.
# Alternatively, you can create any callable with the same signature as the
# __call__ method; a sketch of this alternative follows the class below.
# This simple callback outputs a specific token at each step, irrespective of the prompt.
# Refer to ../bindings/executor/example_logits_processor.py for a more
# sophisticated callback that generates JSON structured output.
# Also refer to sampling_params.py for adding a subclass to the approved class
# list for deserialization.
class MyLogitsProcessor(LogitsProcessor):

    def __init__(self, allowed_token_id: int):
        self.allowed_token_id = allowed_token_id

    def __call__(self, req_id: int, logits: torch.Tensor,
                 token_ids: List[List[int]], stream_ptr: int,
                 client_id: Optional[int]):
        # Build a mask on host that bans every token except allowed_token_id.
        mask = torch.full_like(logits, fill_value=float("-inf"), device="cpu")
        mask[:, :, self.allowed_token_id] = 0

        # Apply the mask on the sampling stream (if any) so the update is
        # ordered with the rest of generation.
        stream = None if stream_ptr is None else torch.cuda.ExternalStream(
            stream_ptr)
        with torch.cuda.stream(stream):
            mask = mask.to(logits.device, non_blocking=True)
            logits += mask


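# A minimal sketch (not part of the original example) of the callable
# alternative described above: any callable with the same signature as
# LogitsProcessor.__call__ can be passed as
# SamplingParams(logits_processor=...). The name `force_token_callback`
# is illustrative, not a library symbol.
def force_token_callback(req_id: int, logits: torch.Tensor,
                         token_ids: List[List[int]], stream_ptr: int,
                         client_id: Optional[int]):
    # Same masking logic as MyLogitsProcessor, hard-coded to token id 42.
    mask = torch.full_like(logits, fill_value=float("-inf"), device="cpu")
    mask[:, :, 42] = 0
    stream = None if stream_ptr is None else torch.cuda.ExternalStream(
        stream_ptr)
    with torch.cuda.stream(stream):
        logits += mask.to(logits.device, non_blocking=True)

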
# The recommended way to create a customized batched logits processor:
#     * Subclass BatchedLogitsProcessor and implement the processing logic in the __call__ method.
#     * Create an instance and pass it to the LLM.
# Alternatively, you can create any callable with the same signature as the __call__ method.
# A batched logits processor receives the arguments for all requests in a batch as lists.
# This helps users optimize the callback for large batch sizes. For example:
# 1. Do more work on the host, e.g. running a JSON state machine, in parallel with the model forward pass on the device.
# 2. Coalesce H2D memory transfers for all requests into a single cudaMemcpyAsync call (sketched after the class below).
# 3. Launch a single batched kernel, e.g. for updating logits on the device.
class MyBatchedLogitsProcessor(BatchedLogitsProcessor):

    def __init__(self, allowed_token_id: int):
        self.allowed_token_id = allowed_token_id

    def __call__(self, req_ids: List[int], logits: List[torch.Tensor],
                 token_ids: List[List[List[int]]], stream_ptr: int,
                 client_ids: List[Optional[int]]):
        # Generate masks for all requests on host
        masks = []
        for req_id, req_logits, req_token_ids, client_id in zip(
                req_ids, logits, token_ids, client_ids):
            mask = torch.full_like(req_logits,
                                   fill_value=float("-inf"),
                                   device="cpu")
            mask[:, :, self.allowed_token_id] = 0
            masks.append(mask)

        # Move masks to device and add to logits using non-blocking operations
        with torch.cuda.stream(torch.cuda.ExternalStream(stream_ptr)):
            for req_logits, mask in zip(logits, masks):
                req_logits += mask.to(req_logits.device, non_blocking=True)

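
# A minimal sketch (not part of the original example) of optimization 2 above:
# assuming every request's logits tensor has the same shape, the per-request
# masks can be stacked into one host tensor and moved to the device with a
# single H2D copy instead of one copy per request. The class name is
# illustrative.
class MyCoalescedBatchedLogitsProcessor(BatchedLogitsProcessor):

    def __init__(self, allowed_token_id: int):
        self.allowed_token_id = allowed_token_id

    def __call__(self, req_ids: List[int], logits: List[torch.Tensor],
                 token_ids: List[List[List[int]]], stream_ptr: int,
                 client_ids: List[Optional[int]]):
        # One stacked mask for the whole batch, built on host.
        masks = torch.full((len(logits), ) + tuple(logits[0].shape),
                           float("-inf"),
                           dtype=logits[0].dtype,
                           device="cpu")
        masks[..., self.allowed_token_id] = 0

        with torch.cuda.stream(torch.cuda.ExternalStream(stream_ptr)):
            # Single coalesced H2D transfer, then per-request in-place adds.
            masks = masks.to(logits[0].device, non_blocking=True)
            for req_logits, mask in zip(logits, masks):
                req_logits += mask
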

def main():

    # A batched logits processor (only supported by the TensorRT backend)
    # must be specified when initializing the LLM.
    llm = LLM(
        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        batched_logits_processor=MyBatchedLogitsProcessor(allowed_token_id=42))

    # Sample prompts
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
    ]

    # Generate text
    for prompt_id, prompt in enumerate(prompts):
        # Use the non-batched logits processor callback only for odd-indexed prompts
        if prompt_id % 2 == 0:
            sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
        else:
            # Each prompt can be specified with a logits processor at runtime
            sampling_params = SamplingParams(
                temperature=0.8,
                top_p=0.95,
                logits_processor=MyLogitsProcessor(allowed_token_id=42))

        for output in llm.generate([prompt], sampling_params):
            print(
                f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}"
            )

    # Example output:
    # Prompt: 'Hello, my name is', Generated text: '\n\nJane Smith. I am a student pursuing my degree in Computer Science at [university]. I enjoy learning new things, especially technology and programming'
    # Prompt: 'The president of the United States is', Generated text: "''''''''''''''''''''''''''''''''"

    # Use the batched logits processor with batch size = 2
    sampling_params = SamplingParams(apply_batched_logits_processor=True)
    for output in llm.generate(prompts, sampling_params):
        print(
            f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}"
        )

    # Example output:
    # Prompt: 'Hello, my name is', Generated text: "''''''''''''''''''''''''''''''''"
    # Prompt: 'The president of the United States is', Generated text: "''''''''''''''''''''''''''''''''"


if __name__ == '__main__':
    main()