71 changes: 64 additions & 7 deletions Sources/AnyLanguageModel/Models/MLXLanguageModel.swift
@@ -178,17 +178,74 @@ import Foundation
includeSchemaInPrompt: Bool,
options: GenerationOptions
) -> sending LanguageModelSession.ResponseStream<Content> where Content: Generable {
// For now, only String is supported
guard type == String.self else {
fatalError("MLXLanguageModel only supports generating String content")
}

// Streaming API in AnyLanguageModel currently yields once; return an empty snapshot
let empty = ""
return LanguageModelSession.ResponseStream(
content: empty as! Content,
rawContent: GeneratedContent(empty)
)
let modelId = self.modelId
let hub = self.hub
let directory = self.directory

let stream: AsyncThrowingStream<LanguageModelSession.ResponseStream<Content>.Snapshot, any Error> = .init { continuation in
let task = Task { @Sendable in
do {
let context: ModelContext
if let directory {
context = try await loadModel(directory: directory)
} else if let hub {
context = try await loadModel(hub: hub, id: modelId)
} else {
context = try await loadModel(id: modelId)
}

let generateParameters = toGenerateParameters(options)

var chat: [MLXLMCommon.Chat.Message] = []

if let instructionSegments = extractInstructionSegments(from: session) {
chat.append(convertSegmentsToMLXSystemMessage(instructionSegments))
}

let userSegments = extractPromptSegments(from: session, fallbackText: prompt.description)
chat.append(convertSegmentsToMLXMessage(userSegments))

let userInput = MLXLMCommon.UserInput(
chat: chat,
processing: .init(resize: .init(width: 512, height: 512)),
tools: nil
)
let lmInput = try await context.processor.prepare(input: userInput)

let mlxStream = try MLXLMCommon.generate(
input: lmInput,
parameters: generateParameters,
context: context
)

var accumulatedText = ""
for await item in mlxStream {
if Task.isCancelled { break }

switch item {
case .chunk(let text):
accumulatedText += text
let raw = GeneratedContent(accumulatedText)
let content: Content.PartiallyGenerated = (accumulatedText as! Content).asPartiallyGenerated()
continuation.yield(.init(content: content, rawContent: raw))
case .info, .toolCall:
break
}
}

continuation.finish()
} catch {
continuation.finish(throwing: error)
}
}
continuation.onTermination = { _ in task.cancel() }
}

return LanguageModelSession.ResponseStream(stream: stream)
}
}
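
For readers unfamiliar with the pattern the new implementation relies on — backing an AsyncThrowingStream with a producing Task and cancelling that task when the consumer stops iterating — here is a minimal, self-contained sketch using only the standard library. The produceChunks closure is a hypothetical stand-in for the MLX token source (MLXLMCommon.generate in the real code):

import Foundation

// Sketch of the stream-backed-by-a-Task pattern, assuming only the
// standard library. `produceChunks` is a hypothetical stand-in for
// the MLX token source.
func makeTextStream(
    produceChunks: @escaping @Sendable () async throws -> [String]
) -> AsyncThrowingStream<String, any Error> {
    AsyncThrowingStream { continuation in
        let task = Task {
            do {
                var accumulated = ""
                for chunk in try await produceChunks() {
                    if Task.isCancelled { break }
                    accumulated += chunk
                    // Yield the running total, mirroring how the model
                    // yields accumulated text in each snapshot.
                    continuation.yield(accumulated)
                }
                continuation.finish()
            } catch {
                continuation.finish(throwing: error)
            }
        }
        // Stop generating if the consumer cancels or breaks out early.
        continuation.onTermination = { _ in task.cancel() }
    }
}

Iterating makeTextStream { ["Hel", "lo"] } yields cumulative snapshots ("Hel", then "Hello"), matching the snapshot semantics above.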

15 changes: 14 additions & 1 deletion Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift
@@ -29,7 +29,7 @@ import Testing
return false
}()

@Suite("MLXLanguageModel", .enabled(if: shouldRunMLXTests))
@Suite("MLXLanguageModel", .enabled(if: shouldRunMLXTests), .serialized)
struct MLXLanguageModelTests {
// Qwen3-0.6B is a small model that supports tool calling
let model = MLXLanguageModel(modelId: "mlx-community/Qwen3-0.6B-4bit")
@@ -42,6 +42,19 @@ import Testing
#expect(!response.content.isEmpty)
}

@Test func streamingResponse() async throws {
let session = LanguageModelSession(model: model)

let stream = session.streamResponse(to: "Count to 5")
var chunks: [String] = []

for try await response in stream {
chunks.append(response.content)
}

#expect(!chunks.isEmpty)
}

@Test func withGenerationOptions() async throws {
let session = LanguageModelSession(model: model)

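
Because each snapshot carries the accumulated text (the implementation appends to accumulatedText before yielding), a consumer that wants only the newly generated suffix can diff consecutive snapshots. A minimal sketch, assuming a configured session as in the tests above:

var previous = ""
for try await response in session.streamResponse(to: "Count to 5") {
    // Each snapshot's content is the full text so far; keep only the new suffix.
    let delta = String(response.content.dropFirst(previous.count))
    print(delta, terminator: "")
    previous = response.content
}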