Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions clients/dotnet/WebClient/MemoryWebClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,18 @@ public async Task<MemoryAnswer> AskAsync(
return JsonSerializer.Deserialize<MemoryAnswer>(json, s_caseInsensitiveJsonOptions) ?? new MemoryAnswer();
}

/// <inheritdoc />
public IAsyncEnumerable<string> AskTextStreaming(
string question,
string? index = null,
MemoryFilter? filter = null,
ICollection<MemoryFilter>? filters = null,
double minRelevance = 0,
CancellationToken cancellationToken = default)
{
throw new NotImplementedException("Streaming text responses is not supported by the Kernel Memory web service");
}

#region private

private static (string contentType, long contentLength, DateTimeOffset lastModified) GetFileDetails(HttpResponseMessage response)
Expand Down
1 change: 1 addition & 0 deletions service/Abstractions/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ public static class Constants

// Endpoints
public const string HttpAskEndpoint = "/ask";
public const string HttpAskTextStreamingEndpoint = "/askstreaming";
public const string HttpSearchEndpoint = "/search";
public const string HttpDownloadEndpoint = "/download";
public const string HttpUploadEndpoint = "/upload";
Expand Down
18 changes: 18 additions & 0 deletions service/Abstractions/IKernelMemory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -211,4 +211,22 @@ public Task<MemoryAnswer> AskAsync(
ICollection<MemoryFilter>? filters = null,
double minRelevance = 0,
CancellationToken cancellationToken = default);

/// <summary>
/// Search the given index for an answer to the given query.
/// </summary>
/// <param name="question">Question to answer</param>
/// <param name="index">Optional index name</param>
/// <param name="filter">Filter to match</param>
/// <param name="filters">Filters to match (using inclusive OR logic). If 'filter' is provided too, the value is merged into this list.</param>
/// <param name="minRelevance">Minimum Cosine Similarity required</param>
/// <param name="cancellationToken">Async task cancellation token</param>
/// <returns>Answer to the query, if possible</returns>
public IAsyncEnumerable<string> AskTextStreaming(
string question,
string? index = null,
MemoryFilter? filter = null,
ICollection<MemoryFilter>? filters = null,
double minRelevance = 0,
CancellationToken cancellationToken = default);
}
16 changes: 16 additions & 0 deletions service/Abstractions/Search/ISearchClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,22 @@ Task<MemoryAnswer> AskAsync(
double minRelevance = 0,
CancellationToken cancellationToken = default);

/// <summary>
/// Answer the given question, if possible, grounding the response with relevant memories matching the given criteria.
/// </summary>
/// <param name="index">Index (aka collection) to search for grounding information</param>
/// <param name="question">Question to answer</param>
/// <param name="filters">Filtering criteria to select memories to consider</param>
/// <param name="minRelevance">Minimum relevance of the memories considered</param>
/// <param name="cancellationToken">Async task cancellation token</param>
/// <returns>Answer to the given question as a stream (IAsyncEnumerable)</returns>
IAsyncEnumerable<string> AskTextStreamingAsync(
string index,
string question,
ICollection<MemoryFilter>? filters = null,
double minRelevance = 0,
CancellationToken cancellationToken = default);

/// <summary>
/// List the available memory indexes (aka collections).
/// </summary>
Expand Down
25 changes: 25 additions & 0 deletions service/Core/MemoryServerless.cs
Original file line number Diff line number Diff line change
Expand Up @@ -257,4 +257,29 @@ public Task<MemoryAnswer> AskAsync(
minRelevance: minRelevance,
cancellationToken: cancellationToken);
}

/// <inheritdoc />
public IAsyncEnumerable<string> AskTextStreaming(
string question,
string? index = null,
MemoryFilter? filter = null,
ICollection<MemoryFilter>? filters = null,
double minRelevance = 0,
CancellationToken cancellationToken = default)
{
if (filter != null)
{
if (filters == null) { filters = new List<MemoryFilter>(); }

filters.Add(filter);
}

index = IndexName.CleanName(index, this._defaultIndexName);
return this._searchClient.AskTextStreamingAsync(
index: index,
question: question,
filters: filters,
minRelevance: minRelevance,
cancellationToken: cancellationToken);
}
}
25 changes: 25 additions & 0 deletions service/Core/MemoryService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -234,4 +234,29 @@ public Task<MemoryAnswer> AskAsync(
minRelevance: minRelevance,
cancellationToken: cancellationToken);
}

/// <inheritdoc />
public IAsyncEnumerable<string> AskTextStreaming(
string question,
string? index = null,
MemoryFilter? filter = null,
ICollection<MemoryFilter>? filters = null,
double minRelevance = 0,
CancellationToken cancellationToken = default)
{
if (filter != null)
{
if (filters == null) { filters = new List<MemoryFilter>(); }

filters.Add(filter);
}

index = IndexName.CleanName(index, this._defaultIndexName);
return this._searchClient.AskTextStreamingAsync(
index: index,
question: question,
filters: filters,
minRelevance: minRelevance,
cancellationToken: cancellationToken);
}
}
144 changes: 95 additions & 49 deletions service/Core/Search/SearchClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -201,6 +202,97 @@ public async Task<MemoryAnswer> AskAsync(
return noAnswerFound;
}

var (facts, relevantSources, factsAvailableCount, factsUsedCount) = await this.GetFactsAsync(question, index, filters, minRelevance, cancellationToken).ConfigureAwait(false);
var answer = noAnswerFound;
answer.RelevantSources = relevantSources;
if (factsAvailableCount > 0 && factsUsedCount == 0)
{
this._log.LogError("Unable to inject memories in the prompt, not enough tokens available");
noAnswerFound.NoResultReason = "Unable to use memories";
return noAnswerFound;
}

if (factsUsedCount == 0)
{
this._log.LogWarning("No memories available");
noAnswerFound.NoResultReason = "No memories available";
return noAnswerFound;
}

var text = new StringBuilder();
var charsGenerated = 0;
var watch = new Stopwatch();
watch.Restart();
await foreach (var x in this.GenerateAnswer(question, facts.ToString())
.WithCancellation(cancellationToken).ConfigureAwait(false))
{
text.Append(x);

if (this._log.IsEnabled(LogLevel.Trace) && text.Length - charsGenerated >= 30)
{
charsGenerated = text.Length;
this._log.LogTrace("{0} chars generated", charsGenerated);
}
}

watch.Stop();

answer.Result = text.ToString();
answer.NoResult = ValueIsEquivalentTo(answer.Result, this._config.EmptyAnswer);
if (answer.NoResult)
{
answer.NoResultReason = "No relevant memories found";
this._log.LogTrace("Answer generated in {0} msecs. No relevant memories found", watch.ElapsedMilliseconds);
}
else
{
this._log.LogTrace("Answer generated in {0} msecs", watch.ElapsedMilliseconds);
}

return answer;
}

/// <inheritdoc />
public async IAsyncEnumerable<string> AskTextStreamingAsync(string index, string question, ICollection<MemoryFilter>? filters = null, double minRelevance = 0, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
if (string.IsNullOrEmpty(question))
{
this._log.LogWarning("No question provided");
yield return this._config.EmptyAnswer;
yield break;
}

var (facts, _, _, factsUsedCount) = await this.GetFactsAsync(question, index, filters, minRelevance, cancellationToken).ConfigureAwait(false);

if (factsUsedCount == 0)
{
this._log.LogError("No memories available or unable to inject memories in the prompt, not enough tokens available");
yield return this._config.EmptyAnswer;
yield break;
}

var text = new StringBuilder();
var charsGenerated = 0;
var watch = new Stopwatch();
watch.Restart();
await foreach (var x in this.GenerateAnswer(question, facts.ToString())
.WithCancellation(cancellationToken).ConfigureAwait(false))
{
yield return x;

if (this._log.IsEnabled(LogLevel.Trace) && text.Length - charsGenerated >= 30)
{
charsGenerated = text.Length;
this._log.LogTrace("{0} chars generated", charsGenerated);
}
}

watch.Stop();
}

private async Task<(string, List<Citation>, int, int)> GetFactsAsync(string question, string index, ICollection<MemoryFilter>? filters, double minRelevance, CancellationToken cancellationToken)
{
var relevantSources = new List<Citation>();
var facts = new StringBuilder();
var maxTokens = this._config.MaxAskPromptSize > 0
? this._config.MaxAskPromptSize
Expand All @@ -212,7 +304,6 @@ public async Task<MemoryAnswer> AskAsync(

var factsUsedCount = 0;
var factsAvailableCount = 0;
var answer = noAnswerFound;

this._log.LogTrace("Fetching relevant memories");
IAsyncEnumerable<(MemoryRecord, double)> matches = this._memoryDb.GetSimilarListAsync(
Expand Down Expand Up @@ -267,11 +358,11 @@ public async Task<MemoryAnswer> AskAsync(
tokensAvailable -= size;

// If the file is already in the list of citations, only add the partition
var citation = answer.RelevantSources.FirstOrDefault(x => x.Link == linkToFile);
var citation = relevantSources.FirstOrDefault(x => x.Link == linkToFile);
if (citation == null)
{
citation = new Citation();
answer.RelevantSources.Add(citation);
relevantSources.Add(citation);
}

// Add the partition to the list of citations
Expand Down Expand Up @@ -299,52 +390,7 @@ public async Task<MemoryAnswer> AskAsync(
break;
}
}

if (factsAvailableCount > 0 && factsUsedCount == 0)
{
this._log.LogError("Unable to inject memories in the prompt, not enough tokens available");
noAnswerFound.NoResultReason = "Unable to use memories";
return noAnswerFound;
}

if (factsUsedCount == 0)
{
this._log.LogWarning("No memories available");
noAnswerFound.NoResultReason = "No memories available";
return noAnswerFound;
}

var text = new StringBuilder();
var charsGenerated = 0;
var watch = new Stopwatch();
watch.Restart();
await foreach (var x in this.GenerateAnswer(question, facts.ToString())
.WithCancellation(cancellationToken).ConfigureAwait(false))
{
text.Append(x);

if (this._log.IsEnabled(LogLevel.Trace) && text.Length - charsGenerated >= 30)
{
charsGenerated = text.Length;
this._log.LogTrace("{0} chars generated", charsGenerated);
}
}

watch.Stop();

answer.Result = text.ToString();
answer.NoResult = ValueIsEquivalentTo(answer.Result, this._config.EmptyAnswer);
if (answer.NoResult)
{
answer.NoResultReason = "No relevant memories found";
this._log.LogTrace("Answer generated in {0} msecs. No relevant memories found", watch.ElapsedMilliseconds);
}
else
{
this._log.LogTrace("Answer generated in {0} msecs", watch.ElapsedMilliseconds);
}

return answer;
return (facts.ToString(), relevantSources, factsAvailableCount, factsUsedCount);
}

private IAsyncEnumerable<string> GenerateAnswer(string question, string facts)
Expand Down
36 changes: 36 additions & 0 deletions service/Service.AspNetCore/WebAPIEndpoints.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using System.IO;
using Microsoft.AspNetCore.Http.HttpResults;
using Microsoft.KernelMemory.DocumentStorage;
using System.Runtime.CompilerServices;

namespace Microsoft.KernelMemory.Service.AspNetCore;

Expand All @@ -28,6 +29,7 @@ public static IEndpointRouteBuilder AddKernelMemoryEndpoints(
builder.AddDeleteIndexesEndpoint(apiPrefix, authFilter);
builder.AddDeleteDocumentsEndpoint(apiPrefix, authFilter);
builder.AddAskEndpoint(apiPrefix, authFilter);
builder.AddAskTextStreamingEndpoint(apiPrefix, authFilter);
builder.AddSearchEndpoint(apiPrefix, authFilter);
builder.AddUploadStatusEndpoint(apiPrefix, authFilter);
builder.AddGetDownloadEndpoint(apiPrefix, authFilter);
Expand Down Expand Up @@ -224,6 +226,40 @@ async Task<IResult> (
if (authFilter != null) { route.AddEndpointFilter(authFilter); }
}

public static void AddAskTextStreamingEndpoint(
this IEndpointRouteBuilder builder, string apiPrefix = "/", IEndpointFilter? authFilter = null)
{
RouteGroupBuilder group = builder.MapGroup(apiPrefix);

// Ask endpoint
var route = group.MapPost(Constants.HttpAskTextStreamingEndpoint, GetStreamAsync)
.Produces<MemoryAnswer>(StatusCodes.Status200OK)
.Produces<ProblemDetails>(StatusCodes.Status401Unauthorized)
.Produces<ProblemDetails>(StatusCodes.Status403Forbidden);

if (authFilter != null) { route.AddEndpointFilter(authFilter); }
}

private static async IAsyncEnumerable<string> GetStreamAsync(
MemoryQuery query,
IKernelMemory service,
ILogger<KernelMemoryWebAPI> log,
[EnumeratorCancellation] CancellationToken cancellationToken)
{
log.LogTrace("New search request, index '{0}', minRelevance {1}", query.Index, query.MinRelevance);
await foreach (var textStream in service.AskTextStreaming(
question: query.Question,
index: query.Index,
filters: query.Filters,
minRelevance: query.MinRelevance,
cancellationToken: cancellationToken).ConfigureAwait(false))
{
await Task.Delay(1, cancellationToken).ConfigureAwait(false); // needed to fix issue where this is not streaming

yield return textStream ?? "";
}
}

public static void AddSearchEndpoint(
this IEndpointRouteBuilder builder, string apiPrefix = "/", IEndpointFilter? authFilter = null)
{
Expand Down