Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonMappingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableMap;
import com.google.genai.types.FunctionCall;
import com.google.genai.types.Part;
import java.util.Base64;
Expand All @@ -37,6 +41,10 @@ private ChatCompletionsCommon() {}

private static final ObjectMapper objectMapper = new ObjectMapper();

static final String EMPTY_JSON_OBJECT = "{}";
static final ImmutableMap<String, Object> EMPTY_PARAMETERS_SCHEMA =
ImmutableMap.of("type", "object", "properties", ImmutableMap.of());

public static final String ROLE_ASSISTANT = "assistant";
public static final String ROLE_MODEL = "model";

Expand Down Expand Up @@ -157,6 +165,21 @@ public Part applyThoughtSignature(Part part) {
}
}

static ImmutableMap<String, Object> parseToolCallArguments(String arguments, ObjectMapper mapper)
throws JsonProcessingException {
if (arguments == null || arguments.trim().isEmpty()) {
return ImmutableMap.of();
}
Map<String, Object> result =
mapper.readValue(arguments, new TypeReference<Map<String, Object>>() {});
if (result == null) {
throw JsonMappingException.from(
(JsonParser) null,
"JSON literal 'null' is not a valid JSON object for tool call arguments");
}
return ImmutableMap.copyOf(result);
}

/**
* See
* https://developers.openai.com/api/reference/resources/chat#(resource)%20chat.completions%20%3E%20(model)%20chat_completion_message_function_tool_call%20%3E%20(schema)
Expand All @@ -181,21 +204,21 @@ public FunctionCall toFunctionCall(@Nullable String toolCallId) {
if (name != null) {
fcBuilder.name(name);
}
if (arguments != null && !arguments.isEmpty()) {
try {
Map<String, Object> args =
objectMapper.readValue(arguments, new TypeReference<Map<String, Object>>() {});
fcBuilder.args(args);
} catch (Exception e) {
throw new IllegalArgumentException(
"Failed to parse function arguments JSON: " + arguments, e);
}
}
fcBuilder.args(parseArguments(arguments));
if (toolCallId != null) {
fcBuilder.id(toolCallId);
}
return fcBuilder.build();
}

private ImmutableMap<String, Object> parseArguments(String arguments) {
try {
return parseToolCallArguments(arguments, objectMapper);
} catch (Exception e) {
throw new IllegalArgumentException(
"Failed to parse function arguments JSON: " + arguments, e);
}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonValue;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonSerializer;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializerProvider;
import com.fasterxml.jackson.databind.module.SimpleModule;
import com.google.adk.JsonBaseModel;
import com.google.adk.models.LlmRequest;
import com.google.common.collect.ImmutableList;
Expand All @@ -32,6 +36,8 @@
import com.google.genai.types.FunctionResponse;
import com.google.genai.types.GenerateContentConfig;
import com.google.genai.types.Part;
import com.google.genai.types.Type;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
Expand Down Expand Up @@ -270,7 +276,28 @@ public final class ChatCompletionsRequest {
public Map<String, Object> extraBody;

private static final Logger logger = LoggerFactory.getLogger(ChatCompletionsRequest.class);
private static final ObjectMapper objectMapper = JsonBaseModel.getMapper();

/**
* Registers a custom serializer to force JSON Schema types to lowercase (e.g., "STRING" ->
* "string"). The genai SDK uses uppercase Enums for schema types, which strict OpenAI-compatible
* endpoints reject with HTTP 400.
*/
private static SimpleModule schemaNormalizerModule() {
SimpleModule module = new SimpleModule();
module.addSerializer(
Type.class,
new JsonSerializer<Type>() {
@Override
public void serialize(Type value, JsonGenerator gen, SerializerProvider serializers)
throws IOException {
gen.writeString(value.toString().toLowerCase());
}
});
return module;
}

private static final ObjectMapper objectMapper =
JsonBaseModel.getMapper().copy().registerModule(schemaNormalizerModule());

/**
* Converts a standard {@link LlmRequest} into a {@link ChatCompletionsRequest} for
Expand Down Expand Up @@ -476,7 +503,10 @@ private static ChatCompletionsCommon.ToolCall processFunctionCallPart(Part part)
function.arguments = objectMapper.writeValueAsString(fc.args().get());
} catch (Exception e) {
logger.warn("Failed to serialize function arguments", e);
function.arguments = ChatCompletionsCommon.EMPTY_JSON_OBJECT;
}
} else {
function.arguments = ChatCompletionsCommon.EMPTY_JSON_OBJECT;
}
toolCall.function = function;
part.thoughtSignature()
Expand Down Expand Up @@ -505,7 +535,10 @@ private static Message processFunctionResponsePart(Part part) {
toolResp.content = new MessageContent(objectMapper.writeValueAsString(fr.response().get()));
} catch (Exception e) {
logger.warn("Failed to serialize tool response", e);
toolResp.content = new MessageContent(ChatCompletionsCommon.EMPTY_JSON_OBJECT);
}
} else {
toolResp.content = new MessageContent(ChatCompletionsCommon.EMPTY_JSON_OBJECT);
}
return toolResp;
}
Expand Down Expand Up @@ -570,12 +603,15 @@ private static void handleTools(GenerateContentConfig config, ChatCompletionsReq
FunctionDefinition def = new FunctionDefinition();
def.name = fd.name().orElse("");
def.description = fd.description().orElse("");
fd.parameters()
.ifPresent(
params ->
def.parameters =
objectMapper.convertValue(
params, new TypeReference<Map<String, Object>>() {}));
if (fd.parameters().isPresent()) {
def.parameters =
objectMapper.convertValue(
fd.parameters().get(), new TypeReference<Map<String, Object>>() {});
} else {
// OpenAI-compatible APIs (like Groq) strictly require the parameters object
// to exist, even for zero-argument functions.
def.parameters = ChatCompletionsCommon.EMPTY_PARAMETERS_SCHEMA;
}
tool.function = def;
tools.add(tool);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.adk.models.LlmResponse;
import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -836,8 +835,7 @@ private ImmutableList<Part> getFinalToolCallParts() {
if (argsSb != null && argsSb.length() > 0) {
try {
Map<String, Object> args =
objectMapper.readValue(
argsSb.toString(), new TypeReference<Map<String, Object>>() {});
ChatCompletionsCommon.parseToolCallArguments(argsSb.toString(), objectMapper);
fc = fc.toBuilder().args(args).build();
part = part.toBuilder().functionCall(fc).build();
} catch (JsonProcessingException e) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package com.google.adk.models.chat;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableMap;
import java.util.Map;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
public final class ChatCompletionsCommonTest {

private ObjectMapper objectMapper;

@Before
public void setUp() {
objectMapper = new ObjectMapper();
}

@Test
public void parseToolCallArguments_withValidJson() throws Exception {
String json = "{\"pr_number\": 1042, \"reason\": \"review\"}";
ImmutableMap<String, Object> args =
ChatCompletionsCommon.parseToolCallArguments(json, objectMapper);
assertThat(args).hasSize(2);
assertThat(args.get("pr_number")).isEqualTo(1042);
assertThat(args.get("reason")).isEqualTo("review");
assertThat(args).isInstanceOf(ImmutableMap.class);
}

@Test
public void parseToolCallArguments_withEmptyString() throws Exception {
Map<String, Object> args = ChatCompletionsCommon.parseToolCallArguments("", objectMapper);
assertThat(args).isEmpty();
}

@Test
public void parseToolCallArguments_withNullString() throws Exception {
Map<String, Object> args = ChatCompletionsCommon.parseToolCallArguments(null, objectMapper);
assertThat(args).isEmpty();
}

@Test
public void parseToolCallArguments_withWhitespaceString() throws Exception {
Map<String, Object> args = ChatCompletionsCommon.parseToolCallArguments(" ", objectMapper);
assertThat(args).isEmpty();
}

@Test
public void parseToolCallArguments_withInvalidJson_throwsException() {
assertThrows(
JsonProcessingException.class,
() -> ChatCompletionsCommon.parseToolCallArguments("none", objectMapper));

assertThrows(
JsonProcessingException.class,
() -> ChatCompletionsCommon.parseToolCallArguments("{bad_json:", objectMapper));
}

@Test
public void parseToolCallArguments_withLiteralNullString_throwsException() {
JsonProcessingException exception =
assertThrows(
JsonProcessingException.class,
() -> ChatCompletionsCommon.parseToolCallArguments("null", objectMapper));
assertThat(exception)
.hasMessageThat()
.contains("JSON literal 'null' is not a valid JSON object for tool call arguments");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import com.google.genai.types.FunctionResponse;
import com.google.genai.types.GenerateContentConfig;
import com.google.genai.types.Part;
import com.google.genai.types.Schema;
import com.google.genai.types.Tool;
import com.google.genai.types.ToolConfig;
import java.util.AbstractMap;
Expand Down Expand Up @@ -567,6 +568,84 @@ public void testFromLlmRequest_withFunctionCall() throws Exception {
assertThat(msg.toolCalls.get(0).function.arguments).isEqualTo("{\"location\":\"Paris\"}");
}

@Test
public void testFromLlmRequest_withAbsentFunctionArguments() throws Exception {
FunctionCall functionCall = FunctionCall.builder().id("call_123").name("get_time").build();
Part part = Part.builder().functionCall(functionCall).build();
Content content = Content.builder().role("model").parts(ImmutableList.of(part)).build();

LlmRequest llmRequest =
LlmRequest.builder().model("gemini-1.5-pro").contents(ImmutableList.of(content)).build();

ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false);

assertThat(request.messages).hasSize(1);
ChatCompletionsRequest.Message msg = request.messages.get(0);
assertThat(msg.role).isEqualTo("assistant");
assertThat(msg.toolCalls).hasSize(1);
assertThat(msg.toolCalls.get(0).function.name).isEqualTo("get_time");
assertThat(msg.toolCalls.get(0).function.arguments).isEqualTo("{}");
}

@Test
public void testFromLlmRequest_withAbsentParameters() throws Exception {
FunctionDeclaration function =
FunctionDeclaration.builder().name("test_func").description("A test function").build();

Tool tool = Tool.builder().functionDeclarations(ImmutableList.of(function)).build();
GenerateContentConfig config =
GenerateContentConfig.builder().tools(ImmutableList.of(tool)).build();

LlmRequest llmRequest =
LlmRequest.builder()
.model("gemini-1.5-pro")
.config(config)
.contents(ImmutableList.of())
.build();

ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false);

assertThat(request.tools).hasSize(1);
Map<String, Object> params = (Map<String, Object>) request.tools.get(0).function.parameters;
assertThat(params.get("type")).isEqualTo("object");
@SuppressWarnings("unchecked")
Map<String, Object> props = (Map<String, Object>) params.get("properties");
assertThat(props).isEmpty();
}

@Test
public void testFromLlmRequest_normalizesSchemaTypeToLowerCase() throws Exception {
Schema param1Schema = Schema.builder().type("STRING").build();

Schema functionSchema =
Schema.builder().type("OBJECT").properties(ImmutableMap.of("param1", param1Schema)).build();

FunctionDeclaration function =
FunctionDeclaration.builder().name("test_func").parameters(functionSchema).build();

Tool tool = Tool.builder().functionDeclarations(ImmutableList.of(function)).build();
GenerateContentConfig config =
GenerateContentConfig.builder().tools(ImmutableList.of(tool)).build();

LlmRequest llmRequest =
LlmRequest.builder()
.model("gemini-1.5-pro")
.config(config)
.contents(ImmutableList.of())
.build();

ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false);

assertThat(request.tools).hasSize(1);
Map<String, Object> params = (Map<String, Object>) request.tools.get(0).function.parameters;
assertThat(params.get("type")).isEqualTo("object");
@SuppressWarnings("unchecked")
Map<String, Object> props = (Map<String, Object>) params.get("properties");
@SuppressWarnings("unchecked")
Map<String, Object> param1 = (Map<String, Object>) props.get("param1");
assertThat(param1.get("type")).isEqualTo("string");
}

@Test
public void testFromLlmRequest_withStreamOptions() throws Exception {
LlmRequest llmRequest =
Expand Down Expand Up @@ -628,11 +707,11 @@ public void testFromLlmRequest_withFunctionResponse() throws Exception {

assertThat(request.messages.get(1).role).isEqualTo("tool");
assertThat(request.messages.get(1).toolCallId).isEmpty();
assertThat(request.messages.get(1).content).isNull();
assertThat(request.messages.get(1).content.getValue()).isEqualTo("{}");

assertThat(request.messages.get(2).role).isEqualTo("tool");
assertThat(request.messages.get(2).toolCallId).isEqualTo("call_faulty");
assertThat(request.messages.get(2).content).isNull();
assertThat(request.messages.get(2).content.getValue()).isEqualTo("{}");
}

@Test
Expand Down
Loading