Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@

package com.google.cloud.vertexai.generativeai;

import static com.google.common.base.Preconditions.checkArgument;

import com.google.cloud.vertexai.api.Content;
import com.google.cloud.vertexai.api.Part;
import com.google.common.base.Strings;

/** Helper class to create content. */
public class ContentMaker {
private static String role = "user";
private static final String DEFAULT_ROLE = "user";

/**
* Creates a ContentMakerForRole for a given role.
Expand All @@ -34,6 +37,7 @@ public static ContentMakerForRole forRole(String role) {
}

private static Content fromStringWithRole(String role, String text) {
checkArgument(!Strings.isNullOrEmpty(text), "text message can't be null or empty.");
return Content.newBuilder().addParts(Part.newBuilder().setText(text)).setRole(role).build();
}

Expand Down Expand Up @@ -61,7 +65,7 @@ private static Content fromMultiModalDataWithRole(String role, Object... multiMo
* <p>To create a text content for "model", use `ContentMaker.forRole("model").fromString(text);
*/
public static Content fromString(String text) {
return fromStringWithRole(role, text);
return fromStringWithRole(DEFAULT_ROLE, text);
}

/**
Expand All @@ -76,8 +80,9 @@ public static Content fromString(String text) {
* could be either a single String or a Part. When it's a single string, it's converted to a
* {@link com.google.cloud.vertexai.api.Part} that has the Text field set.
*/
// TODO(b/333097480) Deprecate ContentMakerForRole
public static Content fromMultiModalData(Object... multiModalData) {
return fromMultiModalDataWithRole(role, multiModalData);
return fromMultiModalDataWithRole(DEFAULT_ROLE, multiModalData);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,6 @@ public GenerativeModel withTools(List<Tool> tools) {
*/
@BetaApi
public CountTokensResponse countTokens(String text) throws IOException {
// TODO(b/330402637): Check null and empty values for the input string.
return countTokens(ContentMaker.fromString(text));
}

Expand All @@ -255,6 +254,7 @@ public CountTokensResponse countTokens(Content content) throws IOException {
*/
@BetaApi
public CountTokensResponse countTokens(List<Content> contents) throws IOException {
checkArgument(contents != null && !contents.isEmpty(), "contents can't be null or empty.");
CountTokensRequest request =
CountTokensRequest.newBuilder()
.setEndpoint(resourceName)
Expand Down Expand Up @@ -287,7 +287,6 @@ private CountTokensResponse countTokensFromRequest(CountTokensRequest request)
* @throws IOException if an I/O error occurs while making the API call
*/
public GenerateContentResponse generateContent(String text) throws IOException {
// TODO(b/330402637): Check null and empty values for the input string.
return generateContent(ContentMaker.fromString(text));
}

Expand Down Expand Up @@ -447,6 +446,7 @@ private ApiFuture<GenerateContentResponse> generateContentAsync(GenerateContentR
* contents and model configurations.
*/
private GenerateContentRequest buildGenerateContentRequest(List<Content> contents) {
checkArgument(contents != null && !contents.isEmpty(), "contents can't be null or empty.");
return GenerateContentRequest.newBuilder()
.setModel(resourceName)
.addAllContents(contents)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.google.cloud.vertexai.generativeai;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;

import com.google.cloud.vertexai.api.Content;
import com.google.protobuf.ByteString;
Expand All @@ -38,6 +39,24 @@ public void fromString_returnsContentWithText() {
assertThat(content.getParts(0).getText()).isEqualTo(stringInput);
}

@Test
public void fromString_throwsIllegalArgumentException_withEmptyString() {
String stringInput = "";

IllegalArgumentException thrown =
assertThrows(IllegalArgumentException.class, () -> ContentMaker.fromString(stringInput));
assertThat(thrown).hasMessageThat().isEqualTo("text message can't be null or empty.");
}

@Test
public void fromString_throwsIllegalArgumentException_withNullString() {
String stringInput = null;

IllegalArgumentException thrown =
assertThrows(IllegalArgumentException.class, () -> ContentMaker.fromString(stringInput));
assertThat(thrown).hasMessageThat().isEqualTo("text message can't be null or empty.");
}

@Test
public void forRole_returnsContentWithArbitraryRoleSet() {
// Although in our docstring, we said only three roles are acceptable, we make sure the code
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import com.google.cloud.vertexai.api.Type;
import com.google.cloud.vertexai.api.VertexAISearch;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
Expand Down Expand Up @@ -452,6 +453,16 @@ public void testGenerateContentwithFluentApi() throws Exception {
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
}

@Test
public void generateContent_withNullContents_throws() throws Exception {
model = new GenerativeModel(MODEL_NAME, vertexAi);
List<Content> contents = null;

IllegalArgumentException thrown =
assertThrows(IllegalArgumentException.class, () -> model.generateContent(contents));
assertThat(thrown).hasMessageThat().isEqualTo("contents can't be null or empty.");
}

@Test
public void testGenerateContentStreamwithText() throws Exception {
model = new GenerativeModel(MODEL_NAME, vertexAi);
Expand Down Expand Up @@ -636,6 +647,16 @@ public void testGenerateContentStreamwithFluentApi() throws Exception {
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
}

@Test
public void generateContentStream_withEmptyContents_throws() throws Exception {
model = new GenerativeModel(MODEL_NAME, vertexAi);
List<Content> contents = new ArrayList<>();

IllegalArgumentException thrown =
assertThrows(IllegalArgumentException.class, () -> model.generateContentStream(contents));
assertThat(thrown).hasMessageThat().isEqualTo("contents can't be null or empty.");
}

@Test
public void generateContentAsync_withText_sendsCorrectRequest() throws Exception {
model = new GenerativeModel(MODEL_NAME, vertexAi);
Expand Down