diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index 379aa6a7..aedb0650 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -83,19 +83,8 @@ import Foundation // Map AnyLanguageModel GenerationOptions to MLX GenerateParameters let generateParameters = toGenerateParameters(options) - // Build chat history starting with system message if instructions are present - var chat: [MLXLMCommon.Chat.Message] = [] - - // Add system message if instructions are present - if let instructionSegments = extractInstructionSegments(from: session) { - let systemMessage = convertSegmentsToMLXSystemMessage(instructionSegments) - chat.append(systemMessage) - } - - // Add user prompt - let userSegments = extractPromptSegments(from: session, fallbackText: prompt.description) - let userMessage = convertSegmentsToMLXMessage(userSegments) - chat.append(userMessage) + // Build chat history from full transcript + var chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description) var allTextChunks: [String] = [] var allEntries: [Transcript.Entry] = [] @@ -208,80 +197,80 @@ import Foundation ) } - // MARK: - Segment Extraction + // MARK: - Transcript Conversion - private func extractPromptSegments(from session: LanguageModelSession, fallbackText: String) -> [Transcript.Segment] - { - // Prefer the most recent Transcript.Prompt entry if present - for entry in session.transcript.reversed() { - if case .prompt(let p) = entry { - return p.segments - } + private func convertTranscriptToMLXChat( + session: LanguageModelSession, + fallbackPrompt: String + ) -> [MLXLMCommon.Chat.Message] { + var chat: [MLXLMCommon.Chat.Message] = [] + + // Check if instructions are already in transcript + let hasInstructionsInTranscript = session.transcript.contains { + if case .instructions = $0 { return true } + return false } - return [.text(.init(content: fallbackText))] - } - private func extractInstructionSegments(from session: LanguageModelSession) -> [Transcript.Segment]? { - // Prefer the first Transcript.Instructions entry if present + // Add instructions from session if present and not in transcript + if !hasInstructionsInTranscript, + let instructions = session.instructions?.description, + !instructions.isEmpty + { + chat.append(.init(role: .system, content: instructions)) + } + + // Convert each transcript entry for entry in session.transcript { - if case .instructions(let i) = entry { - return i.segments + switch entry { + case .instructions(let instr): + chat.append(makeMLXChatMessage(from: instr.segments, role: .system)) + + case .prompt(let prompt): + chat.append(makeMLXChatMessage(from: prompt.segments, role: .user)) + + case .response(let response): + let content = response.segments.map { extractText(from: $0) }.joined(separator: "\n") + chat.append(.assistant(content)) + + case .toolCalls: + // Tool calls are handled inline during generation loop + break + + case .toolOutput(let toolOutput): + let content = toolOutput.segments.map { extractText(from: $0) }.joined(separator: "\n") + chat.append(.tool(content)) } } - // Fallback to session.instructions - if let instructions = session.instructions?.description, !instructions.isEmpty { - return [.text(.init(content: instructions))] + + // If no user message in transcript, add fallback prompt + let hasUserMessage = chat.contains { $0.role == .user } + if !hasUserMessage { + chat.append(.init(role: .user, content: fallbackPrompt)) } - return nil - } - private func convertSegmentsToMLXMessage(_ segments: [Transcript.Segment]) -> MLXLMCommon.Chat.Message { - var textParts: [String] = [] - var images: [MLXLMCommon.UserInput.Image] = [] + return chat + } - for segment in segments { - switch segment { - case .text(let text): - textParts.append(text.content) - case .structure(let structured): - textParts.append(structured.content.jsonString) - case .image(let imageSegment): - switch imageSegment.source { - case .url(let url): - images.append(.url(url)) - case .data(let data, _): - #if canImport(UIKit) - if let uiImage = UIKit.UIImage(data: data), - let ciImage = CIImage(image: uiImage) - { - images.append(.ciImage(ciImage)) - } - #elseif canImport(AppKit) - if let nsImage = AppKit.NSImage(data: data), - let cgImage = nsImage.cgImage(forProposedRect: nil, context: nil, hints: nil) - { - let ciImage = CIImage(cgImage: cgImage) - images.append(.ciImage(ciImage)) - } - #endif - } - } + private func extractText(from segment: Transcript.Segment) -> String { + switch segment { + case .text(let text): + return text.content + case .structure(let structured): + return structured.content.jsonString + case .image: + return "" } - - let content = textParts.joined(separator: "\n") - return MLXLMCommon.Chat.Message(role: .user, content: content, images: images) } - private func convertSegmentsToMLXSystemMessage(_ segments: [Transcript.Segment]) -> MLXLMCommon.Chat.Message { + private func makeMLXChatMessage( + from segments: [Transcript.Segment], + role: MLXLMCommon.Chat.Message.Role + ) -> MLXLMCommon.Chat.Message { var textParts: [String] = [] var images: [MLXLMCommon.UserInput.Image] = [] for segment in segments { switch segment { - case .text(let text): - textParts.append(text.content) - case .structure(let structured): - textParts.append(structured.content.jsonString) case .image(let imageSegment): switch imageSegment.source { case .url(let url): @@ -302,11 +291,16 @@ import Foundation } #endif } + default: + let text = extractText(from: segment) + if !text.isEmpty { + textParts.append(text) + } } } let content = textParts.joined(separator: "\n") - return MLXLMCommon.Chat.Message(role: .system, content: content, images: images) + return MLXLMCommon.Chat.Message(role: role, content: content, images: images) } // MARK: - Tool Conversion