Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
246 changes: 211 additions & 35 deletions pb/c1/connector/v2/entitlement.pb.go

Large diffs are not rendered by default.

402 changes: 402 additions & 0 deletions pb/c1/connector/v2/entitlement.pb.validate.go

Large diffs are not rendered by default.

37 changes: 35 additions & 2 deletions pb/c1/connector/v2/entitlement_grpc.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

50 changes: 50 additions & 0 deletions pkg/connectorbuilder/connectorbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ import (
"context"
"errors"
"fmt"
"io"
"slices"
"sort"
"time"

"github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap"
"go.opentelemetry.io/otel"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
Expand Down Expand Up @@ -826,6 +828,54 @@ func (b *builderImpl) ListEntitlements(ctx context.Context, request *v2.Entitlem
return resp, nil
}

// ListEntitlementsStream returns all the entitlements for a given resource.
func (b *builderImpl) ListEntitlementsStream(s grpc.BidiStreamingServer[v2.EntitlementsServiceListEntitlementsRequestStream, v2.EntitlementsServiceListEntitlementsResponseStream]) error {
ctx := s.Context()
ctx, span := tracer.Start(ctx, "builderImpl.ListEntitlementsStream")
defer span.End()

for {
request, err := s.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
return nil
}
return err
}

start := b.nowFunc()
tt := tasks.ListEntitlementsType
rb, ok := b.resourceBuilders[request.Resource.Id.ResourceType]
if !ok {
b.m.RecordTaskFailure(ctx, tt, b.nowFunc().Sub(start))
return fmt.Errorf("error: list entitlements with unknown resource type %s", request.Resource.Id.ResourceType)
}

out, nextPageToken, annos, err := rb.Entitlements(ctx, request.Resource, &pagination.Token{
Size: int(request.PageSize),
Token: request.PageToken,
})
resp := &v2.EntitlementsServiceListEntitlementsResponseStream{
List: out,
NextPageToken: nextPageToken,
Annotations: annos,
}
if err != nil {
b.m.RecordTaskFailure(ctx, tt, b.nowFunc().Sub(start))
return fmt.Errorf("error: listing entitlements failed: %w", err)
}
if request.PageToken != "" && request.PageToken == nextPageToken {
b.m.RecordTaskFailure(ctx, tt, b.nowFunc().Sub(start))
return fmt.Errorf("error: listing entitlements failed: next page token is the same as the current page token. this is most likely a connector bug")
}

b.m.RecordTaskSuccess(ctx, tt, b.nowFunc().Sub(start))
if err := s.Send(resp); err != nil {
return err
}
}
}

// ListGrants lists all the grants for a given resource.
func (b *builderImpl) ListGrants(ctx context.Context, request *v2.GrantsServiceListGrantsRequest) (*v2.GrantsServiceListGrantsResponse, error) {
ctx, span := tracer.Start(ctx, "builderImpl.ListGrants")
Expand Down
1 change: 1 addition & 0 deletions pkg/dotc1z/c1file.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"time"

"github.com/doug-martin/goqu/v9"

// NOTE: required to register the dialect for goqu.
//
// If you remove this import, goqu.Dialect("sqlite3") will
Expand Down
42 changes: 42 additions & 0 deletions pkg/dotc1z/entitlements.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ package dotc1z

import (
"context"
"errors"
"fmt"
"io"

"github.com/doug-martin/goqu/v9"
"google.golang.org/grpc"
"google.golang.org/protobuf/proto"

v2 "github.com/conductorone/baton-sdk/pb/c1/connector/v2"
Expand Down Expand Up @@ -78,6 +81,45 @@ func (c *C1File) ListEntitlements(ctx context.Context, request *v2.EntitlementsS
}, nil
}

func (c *C1File) ListEntitlementsStream(g grpc.BidiStreamingServer[v2.EntitlementsServiceListEntitlementsRequestStream, v2.EntitlementsServiceListEntitlementsResponseStream]) error {
ctx := g.Context()
ctx, span := tracer.Start(ctx, "C1File.ListEntitlementsStream")
defer span.End()

for {
req, err := g.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
// End of stream
return nil
}
return err
}

objs, nextPageToken, err := c.listConnectorObjects(ctx, entitlements.Name(), req)
if err != nil {
return fmt.Errorf("error listing entitlements: %w", err)
}

ret := make([]*v2.Entitlement, 0, len(objs))
for _, o := range objs {
en := &v2.Entitlement{}
err = proto.Unmarshal(o, en)
if err != nil {
return err
}
ret = append(ret, en)
}

if err := g.Send(&v2.EntitlementsServiceListEntitlementsResponseStream{
List: ret,
NextPageToken: nextPageToken,
}); err != nil {
return err
}
}
}

func (c *C1File) GetEntitlement(ctx context.Context, request *reader_v2.EntitlementsReaderServiceGetEntitlementRequest) (*reader_v2.EntitlementsReaderServiceGetEntitlementResponse, error) {
ctx, span := tracer.Start(ctx, "C1File.GetEntitlement")
defer span.End()
Expand Down
8 changes: 8 additions & 0 deletions pkg/sdk/empty_connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@ func (n *emptyConnector) ListEntitlements(
}, nil
}

// ListEntitlementsStream returns a list of entitlements.
func (n *emptyConnector) ListEntitlementsStream(ctx context.Context, opts ...grpc.CallOption) (
grpc.BidiStreamingClient[v2.EntitlementsServiceListEntitlementsRequestStream, v2.EntitlementsServiceListEntitlementsResponseStream],
error,
) {
return MockBidiClient[v2.EntitlementsServiceListEntitlementsRequestStream, v2.EntitlementsServiceListEntitlementsResponseStream]{}, nil
}

// ListGrants returns a list of grants.
func (n *emptyConnector) ListGrants(ctx context.Context, request *v2.GrantsServiceListGrantsRequest, opts ...grpc.CallOption) (*v2.GrantsServiceListGrantsResponse, error) {
return &v2.GrantsServiceListGrantsResponse{
Expand Down
43 changes: 43 additions & 0 deletions pkg/sdk/mock_grpc_client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package sdk

import (
"context"

"google.golang.org/grpc/metadata"
)

type MockBidiClient[Req any, Resp any] struct {
}

func (m MockBidiClient[Req, Resp]) Send(req *Req) error {
return nil
}

func (m MockBidiClient[Req, Resp]) Recv() (*Resp, error) {
var res Resp
return &res, nil
}

func (m MockBidiClient[Req, Resp]) Header() (metadata.MD, error) {
return nil, nil
}

func (m MockBidiClient[Req, Resp]) Trailer() metadata.MD {
return nil
}

func (m MockBidiClient[Req, Resp]) CloseSend() error {
return nil
}

func (m MockBidiClient[Req, Resp]) Context() context.Context {
return context.Background()
}

func (m MockBidiClient[Req, Resp]) SendMsg(msg any) error {
return nil
}

func (m MockBidiClient[Req, Resp]) RecvMsg(msg any) error {
return nil
}
42 changes: 32 additions & 10 deletions pkg/sync/syncer.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
Expand Down Expand Up @@ -214,6 +215,8 @@ type syncer struct {
skipEGForResourceType map[string]bool
skipEntitlementsAndGrants bool
resourceTypeTraits map[string][]v2.ResourceType_Trait

entitlementStreamer grpc.BidiStreamingClient[v2.EntitlementsServiceListEntitlementsRequestStream, v2.EntitlementsServiceListEntitlementsResponseStream]
}

const minCheckpointInterval = 10 * time.Second
Expand Down Expand Up @@ -1075,13 +1078,31 @@ func (s *syncer) syncEntitlementsForResource(ctx context.Context, resourceID *v2

pageToken := s.state.PageToken(ctx)

resp, err := s.connector.ListEntitlements(ctx, &v2.EntitlementsServiceListEntitlementsRequest{
if s.entitlementStreamer == nil {
entitlementStreamer, err := s.connector.ListEntitlementsStream(ctx)
if err != nil {
return err
}
s.entitlementStreamer = entitlementStreamer
}

err = s.entitlementStreamer.Send(&v2.EntitlementsServiceListEntitlementsRequestStream{
Resource: resourceResponse.Resource,
PageToken: pageToken,
})
if err != nil {
return err
}

resp, err := s.entitlementStreamer.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
s.entitlementStreamer = nil
return errors.New("expected entitlement response but got end of stream")
}
return err
}

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Stream Reuse Causes Sync Failures

The entitlementStreamer is reused across multiple resource syncs, which can lead to request/response mismatches on the bidirectional stream. If Send() or Recv() encounters a non-EOF error, the stream isn't reset, leaving it in a broken state for subsequent calls and causing sync failures.

Fix in Cursor Fix in Web

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Noted, I need to close send on the clients when we exit a sync. In truth, I don't think this is going to end up working at all. We could re-write baton-sdk entirely to be streamified from top to bottom but that's not the world we live in currently.

err = s.store.PutEntitlements(ctx, resp.List...)
if err != nil {
return err
Expand Down Expand Up @@ -1736,7 +1757,7 @@ func (s *syncer) SyncExternalResourcesWithGrantToEntitlement(ctx context.Context
}

ents := make([]*v2.Entitlement, 0)
principals := make([]*v2.Resource, 0)
resources := make([]*v2.Resource, 0)
resourceTypes := make([]*v2.ResourceType, 0)
resourceTypeIDs := mapset.NewSet[string]()
resourceIDs := make(map[string]*v2.ResourceId)
Expand Down Expand Up @@ -1784,17 +1805,17 @@ func (s *syncer) SyncExternalResourcesWithGrantToEntitlement(ctx context.Context
batonID := &v2.BatonID{}
resourceAnnos.Update(batonID)
resourceVal.Annotations = resourceAnnos
principals = append(principals, resourceVal)
resources = append(resources, resourceVal)
}

for _, principal := range principals {
rAnnos := annotations.Annotations(principal.GetAnnotations())
skipEnts := skipEGForResourceType[principal.Id.ResourceType] || rAnnos.Contains(&v2.SkipEntitlementsAndGrants{})
for _, resource := range resources {
rAnnos := annotations.Annotations(resource.GetAnnotations())
skipEnts := skipEGForResourceType[resource.Id.ResourceType] || rAnnos.Contains(&v2.SkipEntitlementsAndGrants{})
if skipEnts {
continue
}

resourceEnts, err := s.listExternalEntitlementsForResource(ctx, principal)
resourceEnts, err := s.listExternalEntitlementsForResource(ctx, resource)
if err != nil {
return err
}
Expand All @@ -1818,7 +1839,7 @@ func (s *syncer) SyncExternalResourcesWithGrantToEntitlement(ctx context.Context
return err
}

err = s.store.PutResources(ctx, principals...)
err = s.store.PutResources(ctx, resources...)
if err != nil {
return err
}
Expand All @@ -1835,12 +1856,12 @@ func (s *syncer) SyncExternalResourcesWithGrantToEntitlement(ctx context.Context

l.Info("Synced external resources for entitlement",
zap.Int("resource_type_count", len(resourceTypes)),
zap.Int("resource_count", len(principals)),
zap.Int("resource_count", len(resources)),
zap.Int("entitlement_count", len(ents)),
zap.Int("grant_count", len(grantsForEnts)),
)

err = s.processGrantsWithExternalPrincipals(ctx, principals)
err = s.processGrantsWithExternalPrincipals(ctx, resources)
if err != nil {
return err
}
Expand Down Expand Up @@ -1984,6 +2005,7 @@ func (s *syncer) listExternalResourcesForResourceType(ctx context.Context, resou

func (s *syncer) listExternalEntitlementsForResource(ctx context.Context, resource *v2.Resource) ([]*v2.Entitlement, error) {
ents := make([]*v2.Entitlement, 0)

entitlementToken := ""
for {
entitlementsList, err := s.externalResourceReader.ListEntitlements(ctx, &v2.EntitlementsServiceListEntitlementsRequest{
Expand Down
Loading
Loading