From 5692c4b5c7eca1a24b49108c607967eb59f3c837 Mon Sep 17 00:00:00 2001 From: Conrad Weidenkeller Date: Mon, 29 Dec 2025 18:57:26 -0600 Subject: [PATCH] feat(db): Create generic transaction handler BED-7079 --- .../src/database/database_integration_test.go | 124 +----------- cmd/api/src/database/db.go | 77 +++---- cmd/api/src/database/transactions.go | 141 +++++++++++++ .../database/transactions_integration_test.go | 188 ++++++++++++++++++ 4 files changed, 373 insertions(+), 157 deletions(-) create mode 100644 cmd/api/src/database/transactions.go create mode 100644 cmd/api/src/database/transactions_integration_test.go diff --git a/cmd/api/src/database/database_integration_test.go b/cmd/api/src/database/database_integration_test.go index d3f7522ce7..35f70b6dd0 100644 --- a/cmd/api/src/database/database_integration_test.go +++ b/cmd/api/src/database/database_integration_test.go @@ -20,17 +20,17 @@ package database_test import ( "context" - "database/sql" "fmt" "net/url" "strings" "testing" "github.com/peterldowns/pgtestdb" + "github.com/stretchr/testify/require" + "github.com/specterops/bloodhound/cmd/api/src/auth" "github.com/specterops/bloodhound/cmd/api/src/database" "github.com/specterops/bloodhound/cmd/api/src/test/integration/utils" - "github.com/stretchr/testify/require" ) type IntegrationTestSuite struct { @@ -116,123 +116,3 @@ func teardownIntegrationTestSuite(t *testing.T, suite *IntegrationTestSuite) { suite.BHDatabase.Close(suite.Context) } } - -func TestTransaction(t *testing.T) { - t.Run("Success: operations commit together", func(t *testing.T) { - testSuite := setupIntegrationTestSuite(t) - defer teardownIntegrationTestSuite(t, &testSuite) - - // Get initial flag state - flag, err := testSuite.BHDatabase.GetFlagByKey(testSuite.Context, "opengraph_search") - require.NoError(t, err) - originalEnabled := flag.Enabled - - // Update flag in a transaction - err = testSuite.BHDatabase.Transaction(testSuite.Context, func(tx *database.BloodhoundDB) error { - flag.Enabled = !originalEnabled - return tx.SetFlag(testSuite.Context, flag) - }) - require.NoError(t, err) - - // Verify the flag was updated - updatedFlag, err := testSuite.BHDatabase.GetFlagByKey(testSuite.Context, "opengraph_search") - require.NoError(t, err) - require.Equal(t, !originalEnabled, updatedFlag.Enabled) - }) - - t.Run("Rollback: error causes operations to rollback", func(t *testing.T) { - testSuite := setupIntegrationTestSuite(t) - defer teardownIntegrationTestSuite(t, &testSuite) - - // Get initial flag state - flag, err := testSuite.BHDatabase.GetFlagByKey(testSuite.Context, "opengraph_search") - require.NoError(t, err) - originalEnabled := flag.Enabled - - // Update flag then return error - should rollback - expectedErr := fmt.Errorf("intentional error to trigger rollback") - err = testSuite.BHDatabase.Transaction(testSuite.Context, func(tx *database.BloodhoundDB) error { - flag.Enabled = !originalEnabled - if err := tx.SetFlag(testSuite.Context, flag); err != nil { - return err - } - return expectedErr - }) - require.ErrorIs(t, err, expectedErr) - - // Verify the flag was NOT updated (rolled back) - unchangedFlag, err := testSuite.BHDatabase.GetFlagByKey(testSuite.Context, "opengraph_search") - require.NoError(t, err) - require.Equal(t, originalEnabled, unchangedFlag.Enabled) - }) - - t.Run("Success: nested method calls work within transaction", func(t *testing.T) { - testSuite := setupIntegrationTestSuite(t) - defer teardownIntegrationTestSuite(t, &testSuite) - - // Verify we can call multiple different methods in a transaction - err := testSuite.BHDatabase.Transaction(testSuite.Context, func(tx *database.BloodhoundDB) error { - // Call GetAllFlags - read operation - flags, err := tx.GetAllFlags(testSuite.Context) - if err != nil { - return err - } - require.NotEmpty(t, flags) - - // Call GetFlagByKey - another read operation - _, err = tx.GetFlagByKey(testSuite.Context, "opengraph_search") - return err - }) - require.NoError(t, err) - }) - - t.Run("Success: transaction with isolation level option", func(t *testing.T) { - testSuite := setupIntegrationTestSuite(t) - defer teardownIntegrationTestSuite(t, &testSuite) - - // Get initial flag state - flag, err := testSuite.BHDatabase.GetFlagByKey(testSuite.Context, "opengraph_search") - require.NoError(t, err) - originalEnabled := flag.Enabled - - // Update flag in a transaction with serializable isolation - err = testSuite.BHDatabase.Transaction(testSuite.Context, func(tx *database.BloodhoundDB) error { - flag.Enabled = !originalEnabled - return tx.SetFlag(testSuite.Context, flag) - }, &sql.TxOptions{Isolation: sql.LevelSerializable}) - require.NoError(t, err) - - // Verify the flag was updated - updatedFlag, err := testSuite.BHDatabase.GetFlagByKey(testSuite.Context, "opengraph_search") - require.NoError(t, err) - require.Equal(t, !originalEnabled, updatedFlag.Enabled) - }) - - t.Run("Success: read-only transaction", func(t *testing.T) { - testSuite := setupIntegrationTestSuite(t) - defer teardownIntegrationTestSuite(t, &testSuite) - - // Read-only transaction should work for read operations - err := testSuite.BHDatabase.Transaction(testSuite.Context, func(tx *database.BloodhoundDB) error { - _, err := tx.GetAllFlags(testSuite.Context) - return err - }, &sql.TxOptions{ReadOnly: true}) - require.NoError(t, err) - }) - - t.Run("Fail: write in read-only transaction", func(t *testing.T) { - testSuite := setupIntegrationTestSuite(t) - defer teardownIntegrationTestSuite(t, &testSuite) - - // Get a flag to modify - flag, err := testSuite.BHDatabase.GetFlagByKey(testSuite.Context, "opengraph_search") - require.NoError(t, err) - - // Attempting to write in a read-only transaction should fail - err = testSuite.BHDatabase.Transaction(testSuite.Context, func(tx *database.BloodhoundDB) error { - flag.Enabled = !flag.Enabled - return tx.SetFlag(testSuite.Context, flag) - }, &sql.TxOptions{ReadOnly: true}) - require.Error(t, err) - }) -} diff --git a/cmd/api/src/database/db.go b/cmd/api/src/database/db.go index 7f776d595f..98810f8650 100644 --- a/cmd/api/src/database/db.go +++ b/cmd/api/src/database/db.go @@ -20,13 +20,15 @@ package database import ( "context" - "database/sql" "errors" "fmt" "log/slog" "time" "github.com/gofrs/uuid" + "gorm.io/driver/postgres" + "gorm.io/gorm" + "github.com/specterops/bloodhound/cmd/api/src/auth" "github.com/specterops/bloodhound/cmd/api/src/database/migration" "github.com/specterops/bloodhound/cmd/api/src/model" @@ -34,8 +36,6 @@ import ( "github.com/specterops/bloodhound/cmd/api/src/services/agi" "github.com/specterops/bloodhound/cmd/api/src/services/dataquality" "github.com/specterops/bloodhound/cmd/api/src/services/upload" - "gorm.io/driver/postgres" - "gorm.io/gorm" ) var ( @@ -192,10 +192,33 @@ type Database interface { } type BloodhoundDB struct { - db *gorm.DB + TransactableDB[*BloodhoundDB, *gorm.DB] idResolver auth.IdentityResolver // TODO: this really needs to be elsewhere. something something separation of concerns } +func NewBloodhoundDB(db *gorm.DB, idResolver auth.IdentityResolver) *BloodhoundDB { + bhdb := &BloodhoundDB{ + idResolver: idResolver, + } + + // Define factory once - reused for nested transactions (no recursion) + var factory TxFactory[*BloodhoundDB, *gorm.DB] + factory = func(txDb *gorm.DB) *BloodhoundDB { + txBhdb := &BloodhoundDB{idResolver: idResolver} + txBhdb.ConfigureTransactable(txDb, + GormExecutor[*BloodhoundDB](), + WithTxFactory(factory), + ) + return txBhdb + } + + bhdb.ConfigureTransactable(db, + GormExecutor[*BloodhoundDB](), + WithTxFactory(factory), + ) + return bhdb +} + func (s *BloodhoundDB) Close(ctx context.Context) { if sqlDBRef, err := s.db.WithContext(ctx).DB(); err != nil { slog.ErrorContext(ctx, fmt.Sprintf("Failed to fetch SQL DB reference from GORM: %v", err)) @@ -222,37 +245,6 @@ func (s *BloodhoundDB) Scope(scopeFuncs ...ScopeFunc) *gorm.DB { return s.db.Scopes(scopes...) } -func NewBloodhoundDB(db *gorm.DB, idResolver auth.IdentityResolver) *BloodhoundDB { - return &BloodhoundDB{db: db, idResolver: idResolver} -} - -// Transaction executes the given function within a database transaction. -// The function receives a new BloodhoundDB instance backed by the transaction, -// allowing all existing methods to participate in the transaction. -// If the function returns an error, the transaction is rolled back. -// If the function returns nil, the transaction is committed. -// Optional sql.TxOptions can be provided to configure isolation level and read-only mode. -func (s *BloodhoundDB) Transaction(ctx context.Context, fn func(tx *BloodhoundDB) error, opts ...*sql.TxOptions) error { - return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { - return fn(NewBloodhoundDB(tx, s.idResolver)) - }, opts...) -} - -func OpenDatabase(connection string) (*gorm.DB, error) { - gormConfig := &gorm.Config{ - Logger: &GormLogAdapter{ - SlowQueryErrorThreshold: time.Second * 10, - SlowQueryWarnThreshold: time.Second * 1, - }, - } - - if db, err := gorm.Open(postgres.Open(connection), gormConfig); err != nil { - return nil, err - } else { - return db, nil - } -} - func (s *BloodhoundDB) RawDelete(value any) error { return CheckError(s.db.Delete(value)) } @@ -286,3 +278,18 @@ func (s *BloodhoundDB) Migrate(ctx context.Context) error { return nil } + +func OpenDatabase(connection string) (*gorm.DB, error) { + gormConfig := &gorm.Config{ + Logger: &GormLogAdapter{ + SlowQueryErrorThreshold: time.Second * 10, + SlowQueryWarnThreshold: time.Second * 1, + }, + } + + if db, err := gorm.Open(postgres.Open(connection), gormConfig); err != nil { + return nil, err + } else { + return db, nil + } +} diff --git a/cmd/api/src/database/transactions.go b/cmd/api/src/database/transactions.go new file mode 100644 index 0000000000..8c861c318a --- /dev/null +++ b/cmd/api/src/database/transactions.go @@ -0,0 +1,141 @@ +// Copyright 2025 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package database + +import ( + "context" + "reflect" + + "gorm.io/gorm" +) + +type Transactable[T any] interface { + Transaction(ctx context.Context, fn func(tx T) error) error +} + +type TransactionExecutor[D any] interface { + RunTransaction(ctx context.Context, db D, fn func(txDb D) error) error +} + +type TxFactory[T any, D any] func(db D) T + +type handle[D any] struct { + db D +} + +type TransactableDB[T Transactable[T], D any] struct { + *handle[D] + executor TransactionExecutor[D] + txFactory TxFactory[T, D] +} + +func (t *TransactableDB[T, D]) Transaction(ctx context.Context, fn func(tx T) error) error { + return t.executor.RunTransaction(ctx, t.db, func(txDb D) error { + return fn(t.txFactory(txDb)) + }) +} + +func (t *TransactableDB[T, D]) DB() D { + return t.db +} + +type TransactableOption[T Transactable[T], D any] func(*TransactableDB[T, D]) + +func WithExecutor[T Transactable[T], D any](executor TransactionExecutor[D]) TransactableOption[T, D] { + return func(t *TransactableDB[T, D]) { + t.executor = executor + } +} + +func WithTxFactory[T Transactable[T], D any](factory TxFactory[T, D]) TransactableOption[T, D] { + return func(t *TransactableDB[T, D]) { + t.txFactory = factory + } +} + +func (t *TransactableDB[T, D]) ConfigureTransactable(db D, opts ...TransactableOption[T, D]) { + t.handle = &handle[D]{db: db} + for _, opt := range opts { + opt(t) + } +} + +func GormExecutor[T Transactable[T]]() TransactableOption[T, *gorm.DB] { + return WithExecutor[T, *gorm.DB](gormExecutor{}) +} + +type gormExecutor struct{} + +func (gormExecutor) RunTransaction(ctx context.Context, db *gorm.DB, fn func(tx *gorm.DB) error) error { + return db.WithContext(ctx).Transaction(fn) +} + +type Wirable[D any] interface { + Wire(db D) +} + +func AutowireEmbedded[D any](parent any, db D) { + parentVal := reflect.ValueOf(parent) + if parentVal.Kind() == reflect.Ptr { + parentVal = parentVal.Elem() + } + if parentVal.Kind() != reflect.Struct { + return + } + + wirableType := reflect.TypeOf((*Wirable[D])(nil)).Elem() + + for i := 0; i < parentVal.NumField(); i++ { + field := parentVal.Field(i) + fieldType := parentVal.Type().Field(i) + + if !fieldType.Anonymous || !fieldType.IsExported() { + continue + } + + var fieldPtr reflect.Value + if field.Kind() == reflect.Ptr { + if field.IsNil() { + field.Set(reflect.New(field.Type().Elem())) + } + fieldPtr = field + } else if field.CanAddr() { + fieldPtr = field.Addr() + } else { + continue + } + + if fieldPtr.Type().Implements(wirableType) { + wireMethod := fieldPtr.MethodByName("Wire") + if wireMethod.IsValid() { + wireMethod.Call([]reflect.Value{reflect.ValueOf(db)}) + } + } + } +} + +func WithAutowire[T Transactable[T], D any](parent any) TransactableOption[T, D] { + return func(t *TransactableDB[T, D]) { + AutowireEmbedded(parent, t.db) + } +} + +func WireDB[D any](db D, fields ...*D) { + for _, f := range fields { + *f = db + } +} diff --git a/cmd/api/src/database/transactions_integration_test.go b/cmd/api/src/database/transactions_integration_test.go new file mode 100644 index 0000000000..952b49e439 --- /dev/null +++ b/cmd/api/src/database/transactions_integration_test.go @@ -0,0 +1,188 @@ +// Copyright 2025 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +//go:build integration + +package database_test + +import ( + "context" + "errors" + "testing" + + "github.com/specterops/bloodhound/cmd/api/src/database" + "github.com/specterops/bloodhound/cmd/api/src/model" + "github.com/stretchr/testify/require" +) + +func TestTransaction_Commit(t *testing.T) { + suite := setupIntegrationTestSuite(t) + + var createdGroup model.AssetGroup + err := suite.BHDatabase.Transaction(suite.Context, func(tx *database.BloodhoundDB) error { + var err error + createdGroup, err = tx.CreateAssetGroup(suite.Context, "test-transaction-group", "test-tx-tag", false) + return err + }) + + require.NoError(t, err) + require.NotZero(t, createdGroup.ID) + + fetchedGroup, err := suite.BHDatabase.GetAssetGroup(suite.Context, createdGroup.ID) + require.NoError(t, err) + require.Equal(t, "test-transaction-group", fetchedGroup.Name) +} + +func TestTransaction_Rollback(t *testing.T) { + suite := setupIntegrationTestSuite(t) + + rollbackErr := errors.New("intentional rollback") + var attemptedGroupID int32 + + err := suite.BHDatabase.Transaction(suite.Context, func(tx *database.BloodhoundDB) error { + createdGroup, err := tx.CreateAssetGroup(suite.Context, "should-rollback-group", "rollback-tag", false) + if err != nil { + return err + } + attemptedGroupID = createdGroup.ID + return rollbackErr + }) + + require.ErrorIs(t, err, rollbackErr) + + _, err = suite.BHDatabase.GetAssetGroup(suite.Context, attemptedGroupID) + require.ErrorIs(t, err, database.ErrNotFound) +} + +func TestTransaction_MultipleOperations(t *testing.T) { + suite := setupIntegrationTestSuite(t) + + var group1, group2 model.AssetGroup + err := suite.BHDatabase.Transaction(suite.Context, func(tx *database.BloodhoundDB) error { + var err error + group1, err = tx.CreateAssetGroup(suite.Context, "multi-op-group-1", "multi-tag-1", false) + if err != nil { + return err + } + + group2, err = tx.CreateAssetGroup(suite.Context, "multi-op-group-2", "multi-tag-2", false) + return err + }) + + require.NoError(t, err) + + fetchedGroup1, err := suite.BHDatabase.GetAssetGroup(suite.Context, group1.ID) + require.NoError(t, err) + require.Equal(t, "multi-op-group-1", fetchedGroup1.Name) + + fetchedGroup2, err := suite.BHDatabase.GetAssetGroup(suite.Context, group2.ID) + require.NoError(t, err) + require.Equal(t, "multi-op-group-2", fetchedGroup2.Name) +} + +func TestTransaction_PartialRollback(t *testing.T) { + suite := setupIntegrationTestSuite(t) + + outsideGroup, err := suite.BHDatabase.CreateAssetGroup(suite.Context, "outside-tx-group", "outside-tag", false) + require.NoError(t, err) + + rollbackErr := errors.New("partial rollback") + err = suite.BHDatabase.Transaction(suite.Context, func(tx *database.BloodhoundDB) error { + _, err := tx.CreateAssetGroup(suite.Context, "inside-tx-group", "inside-tag", false) + if err != nil { + return err + } + return rollbackErr + }) + + require.ErrorIs(t, err, rollbackErr) + + fetchedOutside, err := suite.BHDatabase.GetAssetGroup(suite.Context, outsideGroup.ID) + require.NoError(t, err) + require.Equal(t, "outside-tx-group", fetchedOutside.Name) +} + +func TestTransaction_ContextCancellation(t *testing.T) { + suite := setupIntegrationTestSuite(t) + + ctx, cancel := context.WithCancel(suite.Context) + + err := suite.BHDatabase.Transaction(ctx, func(tx *database.BloodhoundDB) error { + cancel() + _, err := tx.CreateAssetGroup(ctx, "cancelled-group", "cancelled-tag", false) + return err + }) + + require.Error(t, err) +} + +func TestTransaction_NestedTransactions(t *testing.T) { + suite := setupIntegrationTestSuite(t) + + var outerGroup, innerGroup model.AssetGroup + err := suite.BHDatabase.Transaction(suite.Context, func(outerTx *database.BloodhoundDB) error { + var err error + outerGroup, err = outerTx.CreateAssetGroup(suite.Context, "nested-outer-group", "nested-outer-tag", false) + if err != nil { + return err + } + + return outerTx.Transaction(suite.Context, func(innerTx *database.BloodhoundDB) error { + innerGroup, err = innerTx.CreateAssetGroup(suite.Context, "nested-inner-group", "nested-inner-tag", false) + return err + }) + }) + + require.NoError(t, err) + + fetchedOuter, err := suite.BHDatabase.GetAssetGroup(suite.Context, outerGroup.ID) + require.NoError(t, err) + require.Equal(t, "nested-outer-group", fetchedOuter.Name) + + fetchedInner, err := suite.BHDatabase.GetAssetGroup(suite.Context, innerGroup.ID) + require.NoError(t, err) + require.Equal(t, "nested-inner-group", fetchedInner.Name) +} + +func TestTransaction_NestedRollback(t *testing.T) { + suite := setupIntegrationTestSuite(t) + + var outerGroupID int32 + nestedErr := errors.New("nested transaction failure") + + err := suite.BHDatabase.Transaction(suite.Context, func(outerTx *database.BloodhoundDB) error { + outerGroup, err := outerTx.CreateAssetGroup(suite.Context, "nested-rollback-outer", "nested-rb-outer-tag", false) + if err != nil { + return err + } + outerGroupID = outerGroup.ID + + innerErr := outerTx.Transaction(suite.Context, func(innerTx *database.BloodhoundDB) error { + _, err := innerTx.CreateAssetGroup(suite.Context, "nested-rollback-inner", "nested-rb-inner-tag", false) + if err != nil { + return err + } + return nestedErr + }) + + return innerErr + }) + + require.ErrorIs(t, err, nestedErr) + + _, err = suite.BHDatabase.GetAssetGroup(suite.Context, outerGroupID) + require.ErrorIs(t, err, database.ErrNotFound) +}