diff --git a/vertexai/snippets/src/test/java/vertexai/gemini/SnippetsIT.java b/vertexai/snippets/src/test/java/vertexai/gemini/SnippetsIT.java index d39e140fe03..6a77706bd96 100644 --- a/vertexai/snippets/src/test/java/vertexai/gemini/SnippetsIT.java +++ b/vertexai/snippets/src/test/java/vertexai/gemini/SnippetsIT.java @@ -47,8 +47,8 @@ public class SnippetsIT { private static final String PROJECT_ID = System.getenv("GOOGLE_CLOUD_PROJECT"); private static final String LOCATION = "us-central1"; - private static final String GEMINI_FLASH = "gemini-1.5-flash-001"; - private static final String GEMINI_PRO = "gemini-1.5-pro-001"; + private static final String GEMINI_FLASH = "gemini-2.0-flash-001"; + private static final String GEMINI_FLASH_1_5 = "gemini-1.5-flash-001"; private static final String DATASTORE_ID = "grounding-test-datastore_1716831150046"; private static final int MAX_ATTEMPT_COUNT = 3; private static final int INITIAL_BACKOFF_MILLIS = 120000; @@ -183,7 +183,7 @@ public void testSimpleQuestionAnswer() throws Exception { @Test public void testQuickstart() throws IOException { String output = Quickstart.quickstart(PROJECT_ID, LOCATION, GEMINI_FLASH); - assertThat(output).contains("scones"); + assertThat(output).contains("cookie"); } @Test @@ -229,7 +229,7 @@ public void testTokenCount() throws Exception { @Test public void testMediaTokenCount() throws Exception { int tokenCount = GetMediaTokenCount.getMediaTokenCount(PROJECT_ID, LOCATION, GEMINI_FLASH); - assertThat(tokenCount).isEqualTo(16822); + assertThat(tokenCount).isEqualTo(16252); } @Test @@ -314,7 +314,7 @@ public void testSystemInstruction() throws Exception { @Test public void testGroundingWithPublicData() throws Exception { String output = - GroundingWithPublicData.groundWithPublicData(PROJECT_ID, LOCATION, GEMINI_FLASH); + GroundingWithPublicData.groundWithPublicData(PROJECT_ID, LOCATION, GEMINI_FLASH_1_5); assertThat(output).ignoringCase().contains("Rayleigh"); } @@ -364,7 +364,7 @@ public void testControlledGenerationWithMimeType() throws Exception { @Test public void testControlledGenerationWithJsonSchema() throws Exception { String output = ControlledGenerationSchema - .controlGenerationWithJsonSchema(PROJECT_ID, LOCATION, GEMINI_PRO); + .controlGenerationWithJsonSchema(PROJECT_ID, LOCATION, GEMINI_FLASH); Recipe[] recipes = new Gson().fromJson(output, Recipe[].class); assertThat(recipes).isNotEmpty(); @@ -379,7 +379,7 @@ private class Review { @Test public void testControlledGenerationWithJsonSchema2() throws Exception { String output = ControlledGenerationSchema2 - .controlGenerationWithJsonSchema2(PROJECT_ID, LOCATION, GEMINI_PRO); + .controlGenerationWithJsonSchema2(PROJECT_ID, LOCATION, GEMINI_FLASH); Review[] recipes = new Gson().fromJson(output, Review[].class); assertThat(recipes).hasLength(2); @@ -409,7 +409,7 @@ private class DayForecast { @Test public void testControlledGenerationWithJsonSchema3() throws Exception { String output = ControlledGenerationSchema3 - .controlGenerationWithJsonSchema3(PROJECT_ID, LOCATION, GEMINI_PRO); + .controlGenerationWithJsonSchema3(PROJECT_ID, LOCATION, GEMINI_FLASH); WeatherForecast weatherForecast = new Gson().fromJson(output, WeatherForecast.class); assertThat(weatherForecast.forecast).hasLength(7); @@ -473,7 +473,7 @@ private class Item { @Test public void testControlledGenerationWithJsonSchema4() throws Exception { String output = ControlledGenerationSchema4 - .controlGenerationWithJsonSchema4(PROJECT_ID, LOCATION, GEMINI_PRO); + .controlGenerationWithJsonSchema4(PROJECT_ID, LOCATION, GEMINI_FLASH); Item[] items = new Gson().fromJson(output, Item[].class); assertThat(items).isNotEmpty(); @@ -486,7 +486,7 @@ private class Obj { @Test public void testControlledGenerationWithJsonSchema6() throws Exception { String output = ControlledGenerationSchema6 - .controlGenerationWithJsonSchema6(PROJECT_ID, LOCATION, GEMINI_PRO); + .controlGenerationWithJsonSchema6(PROJECT_ID, LOCATION, GEMINI_FLASH); Obj[] objects = new Gson().fromJson(output, Obj[].class); String recognizedObjects = Arrays.stream(objects) @@ -503,7 +503,7 @@ public void testControlledGenerationWithJsonSchema6() throws Exception { @Test public void testGeminiTranslate() throws Exception { String output = GeminiTranslate.geminiTranslate( - PROJECT_ID, LOCATION, GEMINI_PRO, TEXT_TO_TRANSLATE, TARGET_LANGUAGE_CODE); + PROJECT_ID, LOCATION, GEMINI_FLASH, TEXT_TO_TRANSLATE, TARGET_LANGUAGE_CODE); assertThat(output).ignoringCase().contains("Bonjour"); assertThat(output).ignoringCase().contains("aujourd'hui");