diff --git a/db_spec/pg/repository/query_spec.cr b/db_spec/pg/repository/query_spec.cr index 288dcc2..6046e32 100644 --- a/db_spec/pg/repository/query_spec.cr +++ b/db_spec/pg/repository/query_spec.cr @@ -66,6 +66,66 @@ describe "Repository(Postgres)#query" do user.referrer.not_nil!.uuid.should be_a(UUID) end end + + context "in a transaction" do + db = repo.db + query = User + .update + .set(favorite_numbers: [17]) + .where(uuid: user.uuid.not_nil!) + .returning(:favorite_numbers) + + repo.db.transaction do |tx| + repo.db = tx + repo.query(query) + ensure + repo.db = db + end + + it "updates the user" do + query = User.query + .select(:favorite_numbers, :uuid) + .where(uuid: user.uuid.not_nil!) + + updated_user = repo.query(query).first + updated_user.favorite_numbers.should eq([17]) + end + + context "raising an error" do + error_was_raised = false + + query = User + .update + .set(favorite_numbers: [13]) + .where(uuid: user.uuid.not_nil!) + .returning(:favorite_numbers) + + begin + repo.db.transaction do |tx| + repo.db = tx + repo.query(query) + raise "error" + ensure + repo.db = db + end + rescue + error_was_raised = true + end + + it "raises the error" do + error_was_raised.should be_truthy + end + + it "rolls back the transaction" do + query = User.query + .select(:favorite_numbers, :uuid) + .where(uuid: user.uuid.not_nil!) + + updated_user = repo.query(query).first + updated_user.favorite_numbers.should_not eq([13]) + end + end + end end describe "where" do diff --git a/db_spec/sqlite3/repository/query_spec.cr b/db_spec/sqlite3/repository/query_spec.cr index 6c31edd..4934433 100644 --- a/db_spec/sqlite3/repository/query_spec.cr +++ b/db_spec/sqlite3/repository/query_spec.cr @@ -62,6 +62,64 @@ describe "Repository(Postgres)#query" do cursor.rows_affected.should eq 1 end end + + context "in a transaction" do + db = repo.db + query = User + .update + .set(favorite_numbers: [17]) + .where(id: user.id.not_nil!) + + repo.db.transaction do |tx| + repo.db = tx + repo.query(query) + ensure + repo.db = db + end + + it "updates the user" do + query = User.query + .select(:favorite_numbers, :id) + .where(id: user.id.not_nil!) + + updated_user = repo.query(query).first + updated_user.favorite_numbers.should eq([17]) + end + + context "raising an error" do + error_was_raised = false + + query = User + .update + .set(favorite_numbers: [13]) + .where(id: user.id.not_nil!) + + begin + repo.db.transaction do |tx| + repo.db = tx + repo.query(query) + raise "error" + ensure + repo.db = db + end + rescue + error_was_raised = true + end + + it "raises the error" do + error_was_raised.should be_truthy + end + + it "rolls back the transaction" do + query = User.query + .select(:favorite_numbers, :id) + .where(id: user.id.not_nil!) + + updated_user = repo.query(query).first + updated_user.favorite_numbers.should_not eq([13]) + end + end + end end describe "where" do diff --git a/src/onyx-sql/repository.cr b/src/onyx-sql/repository.cr index b0065e1..bcb633a 100644 --- a/src/onyx-sql/repository.cr +++ b/src/onyx-sql/repository.cr @@ -32,6 +32,10 @@ module Onyx::SQL def initialize(@db : ::DB::Database | ::DB::Connection, @logger : Logger = Logger::Standard.new) end + def db=(transaction : ::DB::Transaction) + self.db = transaction.connection + end + protected def postgresql? {% if Object.all_subclasses.any? { |sc| sc.stringify == "PG::Driver" } %} return db.is_a?(PG::Driver)