Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(xai): xAI polish #7722

Merged
merged 2 commits into from
Feb 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions libs/langchain-xai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
"author": "LangChain",
"license": "MIT",
"dependencies": {
"@langchain/openai": "~0.3.0"
"@langchain/openai": "~0.4.4",
"zod": "^3.24.2"
},
"peerDependencies": {
"@langchain/core": ">=0.2.21 <0.4.0"
Expand Down Expand Up @@ -67,7 +68,6 @@
"rollup": "^4.5.2",
"ts-jest": "^29.1.0",
"typescript": "<5.2.0",
"zod": "^3.22.4",
"zod-to-json-schema": "^3.23.1"
},
"publishConfig": {
Expand Down
105 changes: 105 additions & 0 deletions libs/langchain-xai/src/chat_models.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
import { BaseLanguageModelInput } from "@langchain/core/language_models/base";
import {
BaseChatModelCallOptions,
BindToolsInput,
LangSmithParams,
type BaseChatModelParams,
} from "@langchain/core/language_models/chat_models";
import { Serialized } from "@langchain/core/load/serializable";
import { AIMessageChunk, BaseMessage } from "@langchain/core/messages";
import { Runnable } from "@langchain/core/runnables";
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import {
type OpenAICoreRequestOptions,
type OpenAIClient,
ChatOpenAI,
OpenAIToolChoice,
ChatOpenAIStructuredOutputMethodOptions,
} from "@langchain/openai";
import { z } from "zod";

type ChatXAIToolType = BindToolsInput | OpenAIClient.ChatCompletionTool;

Expand Down Expand Up @@ -494,4 +499,104 @@ export class ChatXAI extends ChatOpenAI<ChatXAICallOptions> {

return super.completionWithRetry(newRequest, options);
}

protected override _convertOpenAIDeltaToBaseMessageChunk(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
delta: Record<string, any>,
rawResponse: OpenAIClient.ChatCompletionChunk,
defaultRole?:
| "function"
| "user"
| "system"
| "developer"
| "assistant"
| "tool"
) {
const messageChunk: AIMessageChunk =
super._convertOpenAIDeltaToBaseMessageChunk(
delta,
rawResponse,
defaultRole
);
// Make concatenating chunks work without merge warning
if (!rawResponse.choices[0]?.finish_reason) {
delete messageChunk.response_metadata.usage;
delete messageChunk.usage_metadata;
} else {
messageChunk.usage_metadata = messageChunk.response_metadata.usage;
}
return messageChunk;
}

protected override _convertOpenAIChatCompletionMessageToBaseMessage(
message: OpenAIClient.ChatCompletionMessage,
rawResponse: OpenAIClient.ChatCompletion
) {
const langChainMessage =
super._convertOpenAIChatCompletionMessageToBaseMessage(
message,
rawResponse
);
langChainMessage.additional_kwargs.reasoning_content =
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(message as any).reasoning_content;
return langChainMessage;
}

override withStructuredOutput<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunOutput extends Record<string, any> = Record<string, any>
>(
outputSchema:
| z.ZodType<RunOutput>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
| Record<string, any>,
config?: ChatOpenAIStructuredOutputMethodOptions<false>
): Runnable<BaseLanguageModelInput, RunOutput>;

override withStructuredOutput<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunOutput extends Record<string, any> = Record<string, any>
>(
outputSchema:
| z.ZodType<RunOutput>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
| Record<string, any>,
config?: ChatOpenAIStructuredOutputMethodOptions<true>
): Runnable<BaseLanguageModelInput, { raw: BaseMessage; parsed: RunOutput }>;

override withStructuredOutput<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunOutput extends Record<string, any> = Record<string, any>
>(
outputSchema:
| z.ZodType<RunOutput>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
| Record<string, any>,
config?: ChatOpenAIStructuredOutputMethodOptions<boolean>
):
| Runnable<BaseLanguageModelInput, RunOutput>
| Runnable<BaseLanguageModelInput, { raw: BaseMessage; parsed: RunOutput }>;

override withStructuredOutput<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunOutput extends Record<string, any> = Record<string, any>
>(
outputSchema:
| z.ZodType<RunOutput>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
| Record<string, any>,
config?: ChatOpenAIStructuredOutputMethodOptions<boolean>
):
| Runnable<BaseLanguageModelInput, RunOutput>
| Runnable<
BaseLanguageModelInput,
{ raw: BaseMessage; parsed: RunOutput }
> {
const ensuredConfig = { ...config };
if (ensuredConfig?.method === undefined) {
ensuredConfig.method = "functionCalling";
}
return super.withStructuredOutput<RunOutput>(outputSchema, ensuredConfig);
}
}
8 changes: 4 additions & 4 deletions libs/langchain-xai/src/tests/chat_models.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ test("streaming", async () => {
test("invoke with bound tools", async () => {
const chat = new ChatXAI({
maxRetries: 0,
model: "grok-beta",
model: "grok-2-1212",
});
const message = new HumanMessage("What is the current weather in Hawaii?");
const res = await chat
Expand Down Expand Up @@ -144,7 +144,7 @@ test("stream with bound tools, yielding a single chunk", async () => {

test("Few shotting with tool calls", async () => {
const chat = new ChatXAI({
model: "grok-beta",
model: "grok-2-1212",
temperature: 0,
}).bind({
tools: [
Expand Down Expand Up @@ -194,9 +194,9 @@ test("Few shotting with tool calls", async () => {
expect(res.content).toContain("24");
});

test("Groq can stream tool calls", async () => {
test("xAI can stream tool calls", async () => {
const model = new ChatXAI({
model: "grok-beta",
model: "grok-2-1212",
temperature: 0,
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import { ChatXAI } from "../chat_models.js";
test("withStructuredOutput zod schema function calling", async () => {
const model = new ChatXAI({
temperature: 0,
model: "grok-beta",
model: "grok-2-1212",
});

const calculatorSchema = z.object({
Expand Down Expand Up @@ -37,7 +37,7 @@ test("withStructuredOutput zod schema function calling", async () => {
test("withStructuredOutput zod schema JSON mode", async () => {
const model = new ChatXAI({
temperature: 0,
model: "grok-beta",
model: "grok-2-1212",
});

const calculatorSchema = z.object({
Expand Down Expand Up @@ -76,7 +76,7 @@ Respond with a JSON object containing three keys:
test("withStructuredOutput JSON schema function calling", async () => {
const model = new ChatXAI({
temperature: 0,
model: "grok-beta",
model: "grok-2-1212",
});

const calculatorSchema = z.object({
Expand Down Expand Up @@ -106,7 +106,7 @@ test("withStructuredOutput JSON schema function calling", async () => {
test("withStructuredOutput OpenAI function definition function calling", async () => {
const model = new ChatXAI({
temperature: 0,
model: "grok-beta",
model: "grok-2-1212",
});

const calculatorSchema = z.object({
Expand All @@ -120,14 +120,12 @@ test("withStructuredOutput OpenAI function definition function calling", async (
});

const prompt = ChatPromptTemplate.fromMessages([
"system",
`You are VERY bad at math and must always use a calculator.`,
"human",
"Please help me!! What is 2 + 2?",
["system", `You are VERY bad at math and must always use a calculator.`],
["human", "Please help me!! What is 2 + 2?"],
]);
const chain = prompt.pipe(modelWithStructuredOutput);
const result = await chain.invoke({});
// console.log(result);

expect("operation" in result).toBe(true);
expect("number1" in result).toBe(true);
expect("number2" in result).toBe(true);
Expand All @@ -136,7 +134,7 @@ test("withStructuredOutput OpenAI function definition function calling", async (
test("withStructuredOutput JSON schema JSON mode", async () => {
const model = new ChatXAI({
temperature: 0,
model: "grok-beta",
model: "grok-2-1212",
});

const calculatorSchema = z.object({
Expand Down Expand Up @@ -175,7 +173,7 @@ Respond with a JSON object containing three keys:
test("withStructuredOutput JSON schema", async () => {
const model = new ChatXAI({
temperature: 0,
model: "grok-beta",
model: "grok-2-1212",
});

const jsonSchema = {
Expand Down Expand Up @@ -216,7 +214,7 @@ Respond with a JSON object containing three keys:
test("withStructuredOutput includeRaw true", async () => {
const model = new ChatXAI({
temperature: 0,
model: "grok-beta",
model: "grok-2-1212",
});

const calculatorSchema = z.object({
Expand Down
2 changes: 1 addition & 1 deletion yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -13335,7 +13335,7 @@ __metadata:
rollup: ^4.5.2
ts-jest: ^29.1.0
typescript: <5.2.0
zod: ^3.22.4
zod: ^3.24.2
zod-to-json-schema: ^3.23.1
peerDependencies:
"@langchain/core": ">=0.2.21 <0.4.0"
Expand Down
Loading