Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 65 additions & 71 deletions Sources/AnyLanguageModel/Models/MLXLanguageModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -83,19 +83,8 @@ import Foundation
// Map AnyLanguageModel GenerationOptions to MLX GenerateParameters
let generateParameters = toGenerateParameters(options)

// Build chat history starting with system message if instructions are present
var chat: [MLXLMCommon.Chat.Message] = []

// Add system message if instructions are present
if let instructionSegments = extractInstructionSegments(from: session) {
let systemMessage = convertSegmentsToMLXSystemMessage(instructionSegments)
chat.append(systemMessage)
}

// Add user prompt
let userSegments = extractPromptSegments(from: session, fallbackText: prompt.description)
let userMessage = convertSegmentsToMLXMessage(userSegments)
chat.append(userMessage)
// Build chat history from full transcript
var chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description)

var allTextChunks: [String] = []
var allEntries: [Transcript.Entry] = []
Expand Down Expand Up @@ -208,80 +197,80 @@ import Foundation
)
}

// MARK: - Segment Extraction
// MARK: - Transcript Conversion

private func extractPromptSegments(from session: LanguageModelSession, fallbackText: String) -> [Transcript.Segment]
{
// Prefer the most recent Transcript.Prompt entry if present
for entry in session.transcript.reversed() {
if case .prompt(let p) = entry {
return p.segments
}
private func convertTranscriptToMLXChat(
session: LanguageModelSession,
fallbackPrompt: String
) -> [MLXLMCommon.Chat.Message] {
var chat: [MLXLMCommon.Chat.Message] = []

// Check if instructions are already in transcript
let hasInstructionsInTranscript = session.transcript.contains {
if case .instructions = $0 { return true }
return false
}
return [.text(.init(content: fallbackText))]
}

private func extractInstructionSegments(from session: LanguageModelSession) -> [Transcript.Segment]? {
// Prefer the first Transcript.Instructions entry if present
// Add instructions from session if present and not in transcript
if !hasInstructionsInTranscript,
let instructions = session.instructions?.description,
!instructions.isEmpty
{
chat.append(.init(role: .system, content: instructions))
}

// Convert each transcript entry
for entry in session.transcript {
if case .instructions(let i) = entry {
return i.segments
switch entry {
case .instructions(let instr):
chat.append(makeMLXChatMessage(from: instr.segments, role: .system))

case .prompt(let prompt):
chat.append(makeMLXChatMessage(from: prompt.segments, role: .user))

case .response(let response):
let content = response.segments.map { extractText(from: $0) }.joined(separator: "\n")
chat.append(.assistant(content))

case .toolCalls:
// Tool calls are handled inline during generation loop
break

case .toolOutput(let toolOutput):
let content = toolOutput.segments.map { extractText(from: $0) }.joined(separator: "\n")
chat.append(.tool(content))
}
}
// Fallback to session.instructions
if let instructions = session.instructions?.description, !instructions.isEmpty {
return [.text(.init(content: instructions))]

// If no user message in transcript, add fallback prompt
let hasUserMessage = chat.contains { $0.role == .user }
if !hasUserMessage {
chat.append(.init(role: .user, content: fallbackPrompt))
}
return nil
}

private func convertSegmentsToMLXMessage(_ segments: [Transcript.Segment]) -> MLXLMCommon.Chat.Message {
var textParts: [String] = []
var images: [MLXLMCommon.UserInput.Image] = []
return chat
}

for segment in segments {
switch segment {
case .text(let text):
textParts.append(text.content)
case .structure(let structured):
textParts.append(structured.content.jsonString)
case .image(let imageSegment):
switch imageSegment.source {
case .url(let url):
images.append(.url(url))
case .data(let data, _):
#if canImport(UIKit)
if let uiImage = UIKit.UIImage(data: data),
let ciImage = CIImage(image: uiImage)
{
images.append(.ciImage(ciImage))
}
#elseif canImport(AppKit)
if let nsImage = AppKit.NSImage(data: data),
let cgImage = nsImage.cgImage(forProposedRect: nil, context: nil, hints: nil)
{
let ciImage = CIImage(cgImage: cgImage)
images.append(.ciImage(ciImage))
}
#endif
}
}
private func extractText(from segment: Transcript.Segment) -> String {
switch segment {
case .text(let text):
return text.content
case .structure(let structured):
return structured.content.jsonString
case .image:
return ""
}

let content = textParts.joined(separator: "\n")
return MLXLMCommon.Chat.Message(role: .user, content: content, images: images)
}

private func convertSegmentsToMLXSystemMessage(_ segments: [Transcript.Segment]) -> MLXLMCommon.Chat.Message {
private func makeMLXChatMessage(
from segments: [Transcript.Segment],
role: MLXLMCommon.Chat.Message.Role
) -> MLXLMCommon.Chat.Message {
var textParts: [String] = []
var images: [MLXLMCommon.UserInput.Image] = []

for segment in segments {
switch segment {
case .text(let text):
textParts.append(text.content)
case .structure(let structured):
textParts.append(structured.content.jsonString)
case .image(let imageSegment):
switch imageSegment.source {
case .url(let url):
Expand All @@ -302,11 +291,16 @@ import Foundation
}
#endif
}
default:
let text = extractText(from: segment)
if !text.isEmpty {
textParts.append(text)
}
}
}

let content = textParts.joined(separator: "\n")
return MLXLMCommon.Chat.Message(role: .system, content: content, images: images)
return MLXLMCommon.Chat.Message(role: role, content: content, images: images)
}

// MARK: - Tool Conversion
Expand Down