---
title: Java API
description: Java API reference for ONNX Runtime generate() API
has_children: false
parent: API docs
grand_parent: Generate API (Preview)
nav_order: 4
---
# ONNX Runtime generate() Java API
{: .no_toc }

_Note: this API is in preview and is subject to change._

* TOC placeholder
{:toc}
## Install and import

The Java API is delivered by the `ai.onnxruntime.genai` Java package. Package publication is pending. To build the package from source, see the build from source guide.

```java
import ai.onnxruntime.genai.*;
```
## SimpleGenAI class

The `SimpleGenAI` class provides a simple usage example of the GenAI API. It works with a model that generates text based on a prompt, processing a single prompt at a time.

Usage:

1. Create an instance of the class with the path to the model. The path should also contain the GenAI configuration files.

   ```java
   SimpleGenAI genAI = new SimpleGenAI(folderPath);
   ```

2. Call createGeneratorParams with the prompt text. Set any other search options via the GeneratorParams object as needed using `setSearchOption`.

   ```java
   GeneratorParams generatorParams = genAI.createGeneratorParams(promptText);
   // .. set additional generator params before calling generate()
   ```

3. Call generate with the GeneratorParams object and an optional listener.

   ```java
   String fullResponse = genAI.generate(generatorParams, listener);
   ```

The listener is used as a callback mechanism so that tokens can be used as they are generated. Create a class that implements the `Consumer<String>` interface and provide an instance of that class as the `listener` argument, as sketched below.
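A minimal sketch of such a listener (the class name is illustrative), which simply prints each token as it arrives:

```java
import java.util.function.Consumer;

// Illustrative listener: prints each token to stdout as soon as it is generated.
class PrintingTokenListener implements Consumer<String> {
    @Override
    public void accept(String token) {
        System.out.print(token);
    }
}
```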
### Constructor

```java
public SimpleGenAI(String modelPath) throws GenAIException
```

#### Throws

`GenAIException` - on failure.
### generate

Generate text based on the prompt and settings in GeneratorParams.

NOTE: This only handles a single sequence of input (i.e. a single prompt, which equates to a batch size of 1).

```java
public String generate(GeneratorParams generatorParams, Consumer<String> listener) throws GenAIException
```

#### Parameters

- `generatorParams`: the prompt and settings to run the model with.
- `listener`: optional callback for tokens to be provided as they are generated.

NOTE: Token generation will be blocked until the listener's `accept` method returns.

#### Throws

`GenAIException` - on failure.

#### Returns

The generated text.

#### Example

```java
SimpleGenAI generator = new SimpleGenAI(modelPath);
GeneratorParams params = generator.createGeneratorParams("What's 6 times 7?");
Consumer<String> listener = token -> logger.info("onTokenGenerate: " + token);
String result = generator.generate(params, listener);
logger.info("Result: " + result);
```
### createGeneratorParams

Create the generator parameters and add the prompt text. The user can set other search options via the GeneratorParams object prior to running `generate`.

```java
public GeneratorParams createGeneratorParams(String prompt) throws GenAIException
```

#### Parameters

- `prompt`: the prompt text to encode.

#### Throws

`GenAIException` - on failure.

#### Returns

The generator parameters.
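As an illustrative sketch (the prompt text is arbitrary and `genAI` is the instance created above), the returned parameters can be tuned before generation:

```java
// Hypothetical prompt; adjust search options on the returned params before generating.
GeneratorParams params = genAI.createGeneratorParams("What's 6 times 7?");
params.setSearchOption("max_length", 64);
```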
## GenAIException class

An exception which contains the error message and code produced by the native layer.

### Constructor

```java
public GenAIException(String message)
```

#### Example

```java
try {
    // ... calls into the GenAI API
} catch (GenAIException e) {
    throw new GenAIException("Token generation loop failed.", e);
}
```
## Model class

### Constructor

```java
Model(String modelPath)
```

### createTokenizer

Creates a Tokenizer instance for this model. The model contains the configuration information that determines the tokenizer to use.

```java
public Tokenizer createTokenizer() throws GenAIException
```

#### Throws

`GenAIException` - if the call to the GenAI native API fails.

#### Returns

The tokenizer instance.
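A brief usage sketch; the model folder path is a placeholder:

```java
// Load the model from a folder containing the ONNX model and GenAI configuration files,
// then create the tokenizer that the model is configured to use.
Model model = new Model("path/to/model");
Tokenizer tokenizer = model.createTokenizer();
```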
### generate

```java
public Sequences generate(GeneratorParams generatorParams) throws GenAIException
```

#### Parameters

- `generatorParams`: the generator parameters.

#### Throws

`GenAIException` - if the call to the GenAI native API fails.

#### Returns

The generated sequences.

#### Example

```java
Sequences output = model.generate(generatorParams);
```
### createGeneratorParams

Creates a GeneratorParams instance for executing the model.

NOTE: GeneratorParams internally uses the Model, so the Model instance must remain valid.

```java
public GeneratorParams createGeneratorParams() throws GenAIException
```

#### Throws

`GenAIException` - if the call to the GenAI native API fails.

#### Returns

The GeneratorParams instance.

#### Example

```java
GeneratorParams params = model.createGeneratorParams();
```
## Tokenizer class

### encode

Encodes a string into a sequence of token ids.

```java
public Sequences encode(String string) throws GenAIException
```

#### Parameters

- `string`: text to encode as token ids.

#### Throws

`GenAIException` - if the call to the GenAI native API fails.

#### Returns

A Sequences object with a single sequence in it.

#### Example

```java
Sequences encodedPrompt = tokenizer.encode(prompt);
```
### decode

Decodes a sequence of token ids into text.

```java
public String decode(int[] sequence) throws GenAIException
```

#### Parameters

- `sequence`: collection of token ids to decode to text.

#### Throws

`GenAIException` - if the call to the GenAI native API fails.

#### Returns

The text representation of the sequence.

#### Example

```java
String result = tokenizer.decode(output_ids);
```
### encodeBatch

Encodes an array of strings into a sequence of token ids for each input.

```java
public Sequences encodeBatch(String[] strings) throws GenAIException
```

#### Parameters

- `strings`: collection of strings to encode as token ids.

#### Throws

`GenAIException` - if the call to the GenAI native API fails.

#### Returns

A Sequences object with one sequence per input string.

#### Example

```java
Sequences encoded = tokenizer.encodeBatch(inputs);
```
### decodeBatch

Decodes a batch of sequences of token ids into text.

```java
public String[] decodeBatch(Sequences sequences) throws GenAIException
```

#### Parameters

- `sequences`: a Sequences object with one or more sequences of token ids.

#### Throws

`GenAIException` - if the call to the GenAI native API fails.

#### Returns

An array of strings with the text representation of each sequence.

#### Example

```java
String[] decoded = tokenizer.decodeBatch(encoded);
```
### createStream

Creates a TokenizerStream object for streaming tokenization. This is used with the Generator class to provide each token as it is generated.

```java
public TokenizerStream createStream() throws GenAIException
```

#### Throws

`GenAIException` - if the call to the GenAI native API fails.

#### Returns

The new TokenizerStream instance.
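For example (a short sketch, assuming the `tokenizer` created earlier):

```java
// Create a streaming tokenizer for decoding tokens one at a time during generation.
TokenizerStream stream = tokenizer.createStream();
```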
## TokenizerStream class

This class is used to convert individual tokens when using Generator.generateNextToken.

### decode

```java
public String decode(int token) throws GenAIException
```

#### Parameters

- `token`: int value for token

#### Throws

`GenAIException` - if the call to the GenAI native API fails.
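A short sketch, assuming `stream` was created with `Tokenizer.createStream` and `tokenId` is a token id produced by the generation loop:

```java
// Convert a single generated token id into its corresponding piece of text.
String piece = stream.decode(tokenId);
```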
## Tensor class

### Constructor

Constructs a Tensor with the given data, shape, and element type.

```java
public Tensor(ByteBuffer data, long[] shape, ElementType elementType) throws GenAIException
```

#### Parameters

- `data`: the data for the Tensor. Must be a direct ByteBuffer.
- `shape`: the shape of the Tensor.
- `elementType`: the Type of elements in the Tensor.

#### Throws

`GenAIException` - if the call to the GenAI native API fails.

#### Example

Create a 2x2 Tensor with 32-bit float data.

```java
long[] shape = {2, 2};
ByteBuffer data = ByteBuffer.allocateDirect(4 * Float.BYTES);
// Use the native byte order so the float values are laid out as the native layer expects.
data.order(ByteOrder.nativeOrder());
FloatBuffer floatBuffer = data.asFloatBuffer();
floatBuffer.put(new float[] {1.0f, 2.0f, 3.0f, 4.0f});
Tensor tensor = new Tensor(data, shape, Tensor.ElementType.float32);
```
## GeneratorParams class

The `GeneratorParams` class represents the parameters used for generating sequences with a model. Set the prompt using setInput, and any other search options using setSearchOption.

```java
GeneratorParams params = new GeneratorParams(model);
```
### setSearchOption

```java
public void setSearchOption(String optionName, double value) throws GenAIException
```

#### Throws

`GenAIException` - if the call to the GenAI native API fails.

#### Example

Set search option to limit the model generation length.

```java
generatorParams.setSearchOption("max_length", 10);
```
```java
public void setSearchOption(String optionName, boolean value) throws GenAIException
```

#### Throws

`GenAIException` - if the call to the GenAI native API fails.

#### Example

```java
generatorParams.setSearchOption("early_stopping", true);
```
### setInput

Sets the prompt/s for model execution. The `sequences` are created by using `Tokenizer.encode` or `encodeBatch`.

```java
public void setInput(Sequences sequences) throws GenAIException
```

#### Parameters

- `sequences`: sequences containing the encoded prompt.

#### Throws

`GenAIException` - if the call to the GenAI native API fails.

#### Example

```java
generatorParams.setInput(encodedPrompt);
```
Sets the prompt/s token ids for model execution. The `tokenIds` are the encoded prompt token ids.

```java
public void setInput(int[] tokenIds, int sequenceLength, int batchSize) throws GenAIException
```

#### Parameters

- `tokenIds`: the token ids of the encoded prompt/s.
- `sequenceLength`: the length of each sequence.
- `batchSize`: size of the batch.

#### Throws

`GenAIException` - if the call to the GenAI native API fails.

NOTE: all sequences in the batch must be the same length.

#### Example

```java
generatorParams.setInput(tokenIds, sequenceLength, batchSize);
```
## Generator class

The Generator class generates output using a model and generator parameters.

The expected usage is to loop until `isDone` returns true. Within the loop, call `computeLogits` followed by `generateNextToken`.

The newly generated token can be retrieved with `getLastTokenInSequence` and decoded with `TokenizerStream.decode`.

After the generation process is done, `getSequence` can be used to retrieve the complete generated sequence if needed.
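The following is a minimal sketch of that loop, built only from the classes documented on this page; the model folder path and prompt are placeholders:

```java
// Load the model and set up tokenization and streaming decode.
Model model = new Model("path/to/model");
Tokenizer tokenizer = model.createTokenizer();
TokenizerStream stream = tokenizer.createStream();

// Encode the prompt and configure the generation parameters.
GeneratorParams params = model.createGeneratorParams();
params.setSearchOption("max_length", 128);
params.setInput(tokenizer.encode("What's 6 times 7?"));

// Generate token by token, printing each token as it is produced.
Generator generator = new Generator(model, params);
while (!generator.isDone()) {
    generator.computeLogits();
    generator.generateNextToken();
    System.out.print(stream.decode(generator.getLastTokenInSequence(0)));
}
```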
### Constructor

Constructs a Generator object with the given model and generator parameters.

```java
Generator(Model model, GeneratorParams generatorParams)
```

#### Parameters

- `model`: the model.
- `generatorParams`: the generator parameters.

#### Throws

`GenAIException` - if the call to the GenAI native API fails.
### isDone

Checks if the generation process is done.

```java
public boolean isDone()
```

#### Returns

Returns true if the generation process is done, false otherwise.
### computeLogits

Computes the logits for the next token in the sequence.

```java
public void computeLogits() throws GenAIException
```

#### Throws

`GenAIException` - if the call to the GenAI native API fails.
### getSequence

Retrieves a sequence of token ids for the specified sequence index.

```java
public int[] getSequence(long sequenceIndex) throws GenAIException
```

#### Parameters

- `sequenceIndex`: the index of the sequence.

#### Throws

`GenAIException` - if the call to the GenAI native API fails.

#### Returns

An array of integers with the sequence of token ids.

#### Example

```java
int[] outputIds = output.getSequence(i);
```
### generateNextToken

Generates the next token in the sequence.

```java
public void generateNextToken() throws GenAIException
```

#### Throws

`GenAIException` - if the call to the GenAI native API fails.
### getLastTokenInSequence

Retrieves the last token in the sequence for the specified sequence index.

```java
public int getLastTokenInSequence(long sequenceIndex) throws GenAIException
```

#### Parameters

- `sequenceIndex`: the index of the sequence.

#### Throws

`GenAIException` - if the call to the GenAI native API fails.

#### Returns

The last token in the sequence.
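For illustration (assuming the `generator` from the loop sketch above):

```java
// Get the most recently generated token id for the first (and only) sequence in the batch.
int lastToken = generator.getLastTokenInSequence(0);
```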
## Sequences class

Represents a collection of encoded prompts/responses.

### numSequences

Gets the number of sequences in the collection. This is equivalent to the batch size.

```java
public long numSequences()
```

#### Returns

The number of sequences.

#### Example

```java
int numSequences = (int) sequences.numSequences();
```
### getSequence

Gets the sequence at the specified index.

```java
public int[] getSequence(long sequenceIndex)
```

#### Parameters

- `sequenceIndex`: the index of the sequence.

#### Returns

The sequence as an array of integers.
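As an illustrative sketch (assuming `sequences` was returned by `Model.generate` and `tokenizer` was created from the same model):

```java
// Decode every sequence in the batch back to text.
for (long i = 0; i < sequences.numSequences(); i++) {
    int[] ids = sequences.getSequence(i);
    System.out.println(tokenizer.decode(ids));
}
```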
Coming very soon!