From 8a5b59cf7bd6c520fa218de87c4fc73f624631b5 Mon Sep 17 00:00:00 2001 From: noorbhatia Date: Wed, 17 Dec 2025 13:35:54 +0530 Subject: [PATCH 1/4] Refactor chat history construction to utilize full transcript conversion in MLXLanguageModel --- .../Models/MLXLanguageModel.swift | 93 ++++++++++++------- 1 file changed, 62 insertions(+), 31 deletions(-) diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index 379aa6a7..d3e7ab2f 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -83,19 +83,8 @@ import Foundation // Map AnyLanguageModel GenerationOptions to MLX GenerateParameters let generateParameters = toGenerateParameters(options) - // Build chat history starting with system message if instructions are present - var chat: [MLXLMCommon.Chat.Message] = [] - - // Add system message if instructions are present - if let instructionSegments = extractInstructionSegments(from: session) { - let systemMessage = convertSegmentsToMLXSystemMessage(instructionSegments) - chat.append(systemMessage) - } - - // Add user prompt - let userSegments = extractPromptSegments(from: session, fallbackText: prompt.description) - let userMessage = convertSegmentsToMLXMessage(userSegments) - chat.append(userMessage) + // Build chat history from full transcript + var chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description) var allTextChunks: [String] = [] var allEntries: [Transcript.Entry] = [] @@ -208,31 +197,73 @@ import Foundation ) } - // MARK: - Segment Extraction + // MARK: - Transcript Conversion - private func extractPromptSegments(from session: LanguageModelSession, fallbackText: String) -> [Transcript.Segment] - { - // Prefer the most recent Transcript.Prompt entry if present - for entry in session.transcript.reversed() { - if case .prompt(let p) = entry { - return p.segments - } + /// Converts the full session transcript into MLX chat messages. + private func convertTranscriptToMLXChat( + session: LanguageModelSession, + fallbackPrompt: String + ) -> [MLXLMCommon.Chat.Message] { + var chat: [MLXLMCommon.Chat.Message] = [] + + // Check if instructions are already in transcript + let hasInstructionsInTranscript = session.transcript.contains { + if case .instructions = $0 { return true } + return false + } + + // Add instructions from session if present and not in transcript + if !hasInstructionsInTranscript, + let instructions = session.instructions?.description, + !instructions.isEmpty + { + chat.append(.init(role: .system, content: instructions)) } - return [.text(.init(content: fallbackText))] - } - private func extractInstructionSegments(from session: LanguageModelSession) -> [Transcript.Segment]? { - // Prefer the first Transcript.Instructions entry if present + // Convert each transcript entry for entry in session.transcript { - if case .instructions(let i) = entry { - return i.segments + switch entry { + case .instructions(let instr): + let message = convertSegmentsToMLXSystemMessage(instr.segments) + chat.append(message) + + case .prompt(let prompt): + let message = convertSegmentsToMLXMessage(prompt.segments) + chat.append(message) + + case .response(let response): + let content = response.segments.map { segmentToText($0) }.joined(separator: "\n") + chat.append(.assistant(content)) + + case .toolCalls: + // Tool calls are handled inline during generation loop + break + + case .toolOutput(let toolOutput): + let content = toolOutput.segments.map { segmentToText($0) }.joined(separator: "\n") + chat.append(.tool(content)) } } - // Fallback to session.instructions - if let instructions = session.instructions?.description, !instructions.isEmpty { - return [.text(.init(content: instructions))] + + // If no user message in transcript, add fallback prompt + let hasUserMessage = chat.contains { $0.role == .user } + if !hasUserMessage { + chat.append(.init(role: .user, content: fallbackPrompt)) + } + + return chat + } + + /// Extracts text content from a transcript segment. + private func segmentToText(_ segment: Transcript.Segment) -> String { + switch segment { + case .text(let text): + return text.content + case .structure(let structured): + return structured.content.jsonString + case .image: + return "" } - return nil } private func convertSegmentsToMLXMessage(_ segments: [Transcript.Segment]) -> MLXLMCommon.Chat.Message { From 7032dd08ded27e4bbe401c80263120f55397d534 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Wed, 17 Dec 2025 01:51:42 -0800 Subject: [PATCH 2/4] Rename convertSegmentsToMLXMessage to convertSegmentsToMLXUserMessage --- Sources/AnyLanguageModel/Models/MLXLanguageModel.swift | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index d3e7ab2f..6a52de53 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -228,7 +228,7 @@ import Foundation chat.append(message) case .prompt(let prompt): - let message = convertSegmentsToMLXMessage(prompt.segments) + let message = convertSegmentsToMLXUserMessage(prompt.segments) chat.append(message) case .response(let response): @@ -266,7 +266,7 @@ import Foundation } } - private func convertSegmentsToMLXMessage(_ segments: [Transcript.Segment]) -> MLXLMCommon.Chat.Message { + private func convertSegmentsToMLXUserMessage(_ segments: [Transcript.Segment]) -> MLXLMCommon.Chat.Message { var textParts: [String] = [] var images: [MLXLMCommon.UserInput.Image] = [] From 3d87e162e03a0e21824cd5470c6333e989e2fa64 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Wed, 17 Dec 2025 01:54:23 -0800 Subject: [PATCH 3/4] DRY up creation of MLX messages from segments --- .../Models/MLXLanguageModel.swift | 50 +++---------------- 1 file changed, 7 insertions(+), 43 deletions(-) diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index 6a52de53..9b7f7365 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -224,12 +224,10 @@ import Foundation for entry in session.transcript { switch entry { case .instructions(let instr): - let message = convertSegmentsToMLXSystemMessage(instr.segments) - chat.append(message) + chat.append(makeMLXChatMessage(from: instr.segments, role: .system)) case .prompt(let prompt): - let message = convertSegmentsToMLXUserMessage(prompt.segments) - chat.append(message) + chat.append(makeMLXChatMessage(from: prompt.segments, role: .user)) case .response(let response): let content = response.segments.map { segmentToText($0) }.joined(separator: "\n") @@ -266,7 +264,10 @@ import Foundation } } - private func convertSegmentsToMLXUserMessage(_ segments: [Transcript.Segment]) -> MLXLMCommon.Chat.Message { + private func makeMLXChatMessage( + from segments: [Transcript.Segment], + role: MLXLMCommon.Chat.Message.Role + ) -> MLXLMCommon.Chat.Message { var textParts: [String] = [] var images: [MLXLMCommon.UserInput.Image] = [] @@ -300,44 +301,7 @@ import Foundation } let content = textParts.joined(separator: "\n") - return MLXLMCommon.Chat.Message(role: .user, content: content, images: images) - } - - private func convertSegmentsToMLXSystemMessage(_ segments: [Transcript.Segment]) -> MLXLMCommon.Chat.Message { - var textParts: [String] = [] - var images: [MLXLMCommon.UserInput.Image] = [] - - for segment in segments { - switch segment { - case .text(let text): - textParts.append(text.content) - case .structure(let structured): - textParts.append(structured.content.jsonString) - case .image(let imageSegment): - switch imageSegment.source { - case .url(let url): - images.append(.url(url)) - case .data(let data, _): - #if canImport(UIKit) - if let uiImage = UIKit.UIImage(data: data), - let ciImage = CIImage(image: uiImage) - { - images.append(.ciImage(ciImage)) - } - #elseif canImport(AppKit) - if let nsImage = AppKit.NSImage(data: data), - let cgImage = nsImage.cgImage(forProposedRect: nil, context: nil, hints: nil) - { - let ciImage = CIImage(cgImage: cgImage) - images.append(.ciImage(ciImage)) - } - #endif - } - } - } - - let content = textParts.joined(separator: "\n") - return MLXLMCommon.Chat.Message(role: .system, content: content, images: images) + return MLXLMCommon.Chat.Message(role: role, content: content, images: images) } // MARK: - Tool Conversion From 1b0db52ad81924307f5c40466f996f065cebb168 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Wed, 17 Dec 2025 02:18:41 -0800 Subject: [PATCH 4/4] Rename segmentToText to extractText and refactor makeMLXChatMessage to use it --- .../Models/MLXLanguageModel.swift | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index 9b7f7365..aedb0650 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -199,7 +199,6 @@ import Foundation // MARK: - Transcript Conversion - /// Converts the full session transcript into MLX chat messages. private func convertTranscriptToMLXChat( session: LanguageModelSession, fallbackPrompt: String @@ -230,7 +229,7 @@ import Foundation chat.append(makeMLXChatMessage(from: prompt.segments, role: .user)) case .response(let response): - let content = response.segments.map { segmentToText($0) }.joined(separator: "\n") + let content = response.segments.map { extractText(from: $0) }.joined(separator: "\n") chat.append(.assistant(content)) case .toolCalls: @@ -238,7 +237,7 @@ import Foundation break case .toolOutput(let toolOutput): - let content = toolOutput.segments.map { segmentToText($0) }.joined(separator: "\n") + let content = toolOutput.segments.map { extractText(from: $0) }.joined(separator: "\n") chat.append(.tool(content)) } } @@ -252,8 +251,7 @@ import Foundation return chat } - /// Extracts text content from a transcript segment. - private func segmentToText(_ segment: Transcript.Segment) -> String { + private func extractText(from segment: Transcript.Segment) -> String { switch segment { case .text(let text): return text.content @@ -273,10 +271,6 @@ import Foundation for segment in segments { switch segment { - case .text(let text): - textParts.append(text.content) - case .structure(let structured): - textParts.append(structured.content.jsonString) case .image(let imageSegment): switch imageSegment.source { case .url(let url): @@ -297,6 +291,11 @@ import Foundation } #endif } + default: + let text = extractText(from: segment) + if !text.isEmpty { + textParts.append(text) + } } }