From 4da015a7dcfadc2ce5b7a027af009eb4088ebcf3 Mon Sep 17 00:00:00 2001 From: Michael Keller Date: Thu, 6 Jun 2024 21:51:41 +0200 Subject: [PATCH] Add streaming response (text only) as discussed in #625 --- clients/dotnet/WebClient/MemoryWebClient.cs | 12 ++ service/Abstractions/Constants.cs | 1 + service/Abstractions/IKernelMemory.cs | 18 +++ service/Abstractions/Search/ISearchClient.cs | 16 ++ service/Core/MemoryServerless.cs | 25 +++ service/Core/MemoryService.cs | 25 +++ service/Core/Search/SearchClient.cs | 144 ++++++++++++------ service/Service.AspNetCore/WebAPIEndpoints.cs | 36 +++++ 8 files changed, 228 insertions(+), 49 deletions(-) diff --git a/clients/dotnet/WebClient/MemoryWebClient.cs b/clients/dotnet/WebClient/MemoryWebClient.cs index a56aa3eb2..3d71591ec 100644 --- a/clients/dotnet/WebClient/MemoryWebClient.cs +++ b/clients/dotnet/WebClient/MemoryWebClient.cs @@ -358,6 +358,18 @@ public async Task AskAsync( return JsonSerializer.Deserialize(json, s_caseInsensitiveJsonOptions) ?? new MemoryAnswer(); } + /// + public IAsyncEnumerable AskTextStreaming( + string question, + string? index = null, + MemoryFilter? filter = null, + ICollection? filters = null, + double minRelevance = 0, + CancellationToken cancellationToken = default) + { + throw new NotImplementedException("Streaming text responses is not supported by the Kernel Memory web service"); + } + #region private private static (string contentType, long contentLength, DateTimeOffset lastModified) GetFileDetails(HttpResponseMessage response) diff --git a/service/Abstractions/Constants.cs b/service/Abstractions/Constants.cs index 820e4af20..7e4b7f75b 100644 --- a/service/Abstractions/Constants.cs +++ b/service/Abstractions/Constants.cs @@ -52,6 +52,7 @@ public static class Constants // Endpoints public const string HttpAskEndpoint = "/ask"; + public const string HttpAskTextStreamingEndpoint = "/askstreaming"; public const string HttpSearchEndpoint = "/search"; public const string HttpDownloadEndpoint = "/download"; public const string HttpUploadEndpoint = "/upload"; diff --git a/service/Abstractions/IKernelMemory.cs b/service/Abstractions/IKernelMemory.cs index 00ce78632..7f07051ea 100644 --- a/service/Abstractions/IKernelMemory.cs +++ b/service/Abstractions/IKernelMemory.cs @@ -211,4 +211,22 @@ public Task AskAsync( ICollection? filters = null, double minRelevance = 0, CancellationToken cancellationToken = default); + + /// + /// Search the given index for an answer to the given query. + /// + /// Question to answer + /// Optional index name + /// Filter to match + /// Filters to match (using inclusive OR logic). If 'filter' is provided too, the value is merged into this list. + /// Minimum Cosine Similarity required + /// Async task cancellation token + /// Answer to the query, if possible + public IAsyncEnumerable AskTextStreaming( + string question, + string? index = null, + MemoryFilter? filter = null, + ICollection? filters = null, + double minRelevance = 0, + CancellationToken cancellationToken = default); } diff --git a/service/Abstractions/Search/ISearchClient.cs b/service/Abstractions/Search/ISearchClient.cs index 9925ba052..2350de2fa 100644 --- a/service/Abstractions/Search/ISearchClient.cs +++ b/service/Abstractions/Search/ISearchClient.cs @@ -45,6 +45,22 @@ Task AskAsync( double minRelevance = 0, CancellationToken cancellationToken = default); + /// + /// Answer the given question, if possible, grounding the response with relevant memories matching the given criteria. + /// + /// Index (aka collection) to search for grounding information + /// Question to answer + /// Filtering criteria to select memories to consider + /// Minimum relevance of the memories considered + /// Async task cancellation token + /// Answer to the given question as a stream (IAsyncEnumerable) + IAsyncEnumerable AskTextStreamingAsync( + string index, + string question, + ICollection? filters = null, + double minRelevance = 0, + CancellationToken cancellationToken = default); + /// /// List the available memory indexes (aka collections). /// diff --git a/service/Core/MemoryServerless.cs b/service/Core/MemoryServerless.cs index d8f9ff78b..3f644ace8 100644 --- a/service/Core/MemoryServerless.cs +++ b/service/Core/MemoryServerless.cs @@ -257,4 +257,29 @@ public Task AskAsync( minRelevance: minRelevance, cancellationToken: cancellationToken); } + + /// + public IAsyncEnumerable AskTextStreaming( + string question, + string? index = null, + MemoryFilter? filter = null, + ICollection? filters = null, + double minRelevance = 0, + CancellationToken cancellationToken = default) + { + if (filter != null) + { + if (filters == null) { filters = new List(); } + + filters.Add(filter); + } + + index = IndexName.CleanName(index, this._defaultIndexName); + return this._searchClient.AskTextStreamingAsync( + index: index, + question: question, + filters: filters, + minRelevance: minRelevance, + cancellationToken: cancellationToken); + } } diff --git a/service/Core/MemoryService.cs b/service/Core/MemoryService.cs index 818e3d6c4..ab2f0501b 100644 --- a/service/Core/MemoryService.cs +++ b/service/Core/MemoryService.cs @@ -234,4 +234,29 @@ public Task AskAsync( minRelevance: minRelevance, cancellationToken: cancellationToken); } + + /// + public IAsyncEnumerable AskTextStreaming( + string question, + string? index = null, + MemoryFilter? filter = null, + ICollection? filters = null, + double minRelevance = 0, + CancellationToken cancellationToken = default) + { + if (filter != null) + { + if (filters == null) { filters = new List(); } + + filters.Add(filter); + } + + index = IndexName.CleanName(index, this._defaultIndexName); + return this._searchClient.AskTextStreamingAsync( + index: index, + question: question, + filters: filters, + minRelevance: minRelevance, + cancellationToken: cancellationToken); + } } diff --git a/service/Core/Search/SearchClient.cs b/service/Core/Search/SearchClient.cs index 8a7eaf269..1da7421da 100644 --- a/service/Core/Search/SearchClient.cs +++ b/service/Core/Search/SearchClient.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.Linq; +using System.Runtime.CompilerServices; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -201,6 +202,97 @@ public async Task AskAsync( return noAnswerFound; } + var (facts, relevantSources, factsAvailableCount, factsUsedCount) = await this.GetFactsAsync(question, index, filters, minRelevance, cancellationToken).ConfigureAwait(false); + var answer = noAnswerFound; + answer.RelevantSources = relevantSources; + if (factsAvailableCount > 0 && factsUsedCount == 0) + { + this._log.LogError("Unable to inject memories in the prompt, not enough tokens available"); + noAnswerFound.NoResultReason = "Unable to use memories"; + return noAnswerFound; + } + + if (factsUsedCount == 0) + { + this._log.LogWarning("No memories available"); + noAnswerFound.NoResultReason = "No memories available"; + return noAnswerFound; + } + + var text = new StringBuilder(); + var charsGenerated = 0; + var watch = new Stopwatch(); + watch.Restart(); + await foreach (var x in this.GenerateAnswer(question, facts.ToString()) + .WithCancellation(cancellationToken).ConfigureAwait(false)) + { + text.Append(x); + + if (this._log.IsEnabled(LogLevel.Trace) && text.Length - charsGenerated >= 30) + { + charsGenerated = text.Length; + this._log.LogTrace("{0} chars generated", charsGenerated); + } + } + + watch.Stop(); + + answer.Result = text.ToString(); + answer.NoResult = ValueIsEquivalentTo(answer.Result, this._config.EmptyAnswer); + if (answer.NoResult) + { + answer.NoResultReason = "No relevant memories found"; + this._log.LogTrace("Answer generated in {0} msecs. No relevant memories found", watch.ElapsedMilliseconds); + } + else + { + this._log.LogTrace("Answer generated in {0} msecs", watch.ElapsedMilliseconds); + } + + return answer; + } + + /// + public async IAsyncEnumerable AskTextStreamingAsync(string index, string question, ICollection? filters = null, double minRelevance = 0, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + if (string.IsNullOrEmpty(question)) + { + this._log.LogWarning("No question provided"); + yield return this._config.EmptyAnswer; + yield break; + } + + var (facts, _, _, factsUsedCount) = await this.GetFactsAsync(question, index, filters, minRelevance, cancellationToken).ConfigureAwait(false); + + if (factsUsedCount == 0) + { + this._log.LogError("No memories available or unable to inject memories in the prompt, not enough tokens available"); + yield return this._config.EmptyAnswer; + yield break; + } + + var text = new StringBuilder(); + var charsGenerated = 0; + var watch = new Stopwatch(); + watch.Restart(); + await foreach (var x in this.GenerateAnswer(question, facts.ToString()) + .WithCancellation(cancellationToken).ConfigureAwait(false)) + { + yield return x; + + if (this._log.IsEnabled(LogLevel.Trace) && text.Length - charsGenerated >= 30) + { + charsGenerated = text.Length; + this._log.LogTrace("{0} chars generated", charsGenerated); + } + } + + watch.Stop(); + } + + private async Task<(string, List, int, int)> GetFactsAsync(string question, string index, ICollection? filters, double minRelevance, CancellationToken cancellationToken) + { + var relevantSources = new List(); var facts = new StringBuilder(); var maxTokens = this._config.MaxAskPromptSize > 0 ? this._config.MaxAskPromptSize @@ -212,7 +304,6 @@ public async Task AskAsync( var factsUsedCount = 0; var factsAvailableCount = 0; - var answer = noAnswerFound; this._log.LogTrace("Fetching relevant memories"); IAsyncEnumerable<(MemoryRecord, double)> matches = this._memoryDb.GetSimilarListAsync( @@ -267,11 +358,11 @@ public async Task AskAsync( tokensAvailable -= size; // If the file is already in the list of citations, only add the partition - var citation = answer.RelevantSources.FirstOrDefault(x => x.Link == linkToFile); + var citation = relevantSources.FirstOrDefault(x => x.Link == linkToFile); if (citation == null) { citation = new Citation(); - answer.RelevantSources.Add(citation); + relevantSources.Add(citation); } // Add the partition to the list of citations @@ -299,52 +390,7 @@ public async Task AskAsync( break; } } - - if (factsAvailableCount > 0 && factsUsedCount == 0) - { - this._log.LogError("Unable to inject memories in the prompt, not enough tokens available"); - noAnswerFound.NoResultReason = "Unable to use memories"; - return noAnswerFound; - } - - if (factsUsedCount == 0) - { - this._log.LogWarning("No memories available"); - noAnswerFound.NoResultReason = "No memories available"; - return noAnswerFound; - } - - var text = new StringBuilder(); - var charsGenerated = 0; - var watch = new Stopwatch(); - watch.Restart(); - await foreach (var x in this.GenerateAnswer(question, facts.ToString()) - .WithCancellation(cancellationToken).ConfigureAwait(false)) - { - text.Append(x); - - if (this._log.IsEnabled(LogLevel.Trace) && text.Length - charsGenerated >= 30) - { - charsGenerated = text.Length; - this._log.LogTrace("{0} chars generated", charsGenerated); - } - } - - watch.Stop(); - - answer.Result = text.ToString(); - answer.NoResult = ValueIsEquivalentTo(answer.Result, this._config.EmptyAnswer); - if (answer.NoResult) - { - answer.NoResultReason = "No relevant memories found"; - this._log.LogTrace("Answer generated in {0} msecs. No relevant memories found", watch.ElapsedMilliseconds); - } - else - { - this._log.LogTrace("Answer generated in {0} msecs", watch.ElapsedMilliseconds); - } - - return answer; + return (facts.ToString(), relevantSources, factsAvailableCount, factsUsedCount); } private IAsyncEnumerable GenerateAnswer(string question, string facts) diff --git a/service/Service.AspNetCore/WebAPIEndpoints.cs b/service/Service.AspNetCore/WebAPIEndpoints.cs index 0d6d5e5c6..afd6db54a 100644 --- a/service/Service.AspNetCore/WebAPIEndpoints.cs +++ b/service/Service.AspNetCore/WebAPIEndpoints.cs @@ -13,6 +13,7 @@ using System.IO; using Microsoft.AspNetCore.Http.HttpResults; using Microsoft.KernelMemory.DocumentStorage; +using System.Runtime.CompilerServices; namespace Microsoft.KernelMemory.Service.AspNetCore; @@ -28,6 +29,7 @@ public static IEndpointRouteBuilder AddKernelMemoryEndpoints( builder.AddDeleteIndexesEndpoint(apiPrefix, authFilter); builder.AddDeleteDocumentsEndpoint(apiPrefix, authFilter); builder.AddAskEndpoint(apiPrefix, authFilter); + builder.AddAskTextStreamingEndpoint(apiPrefix, authFilter); builder.AddSearchEndpoint(apiPrefix, authFilter); builder.AddUploadStatusEndpoint(apiPrefix, authFilter); builder.AddGetDownloadEndpoint(apiPrefix, authFilter); @@ -224,6 +226,40 @@ async Task ( if (authFilter != null) { route.AddEndpointFilter(authFilter); } } + public static void AddAskTextStreamingEndpoint( + this IEndpointRouteBuilder builder, string apiPrefix = "/", IEndpointFilter? authFilter = null) + { + RouteGroupBuilder group = builder.MapGroup(apiPrefix); + + // Ask endpoint + var route = group.MapPost(Constants.HttpAskTextStreamingEndpoint, GetStreamAsync) + .Produces(StatusCodes.Status200OK) + .Produces(StatusCodes.Status401Unauthorized) + .Produces(StatusCodes.Status403Forbidden); + + if (authFilter != null) { route.AddEndpointFilter(authFilter); } + } + + private static async IAsyncEnumerable GetStreamAsync( + MemoryQuery query, + IKernelMemory service, + ILogger log, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + log.LogTrace("New search request, index '{0}', minRelevance {1}", query.Index, query.MinRelevance); + await foreach (var textStream in service.AskTextStreaming( + question: query.Question, + index: query.Index, + filters: query.Filters, + minRelevance: query.MinRelevance, + cancellationToken: cancellationToken).ConfigureAwait(false)) + { + await Task.Delay(1, cancellationToken).ConfigureAwait(false); // needed to fix issue where this is not streaming + + yield return textStream ?? ""; + } + } + public static void AddSearchEndpoint( this IEndpointRouteBuilder builder, string apiPrefix = "/", IEndpointFilter? authFilter = null) {