
Commit 1fff5a6

Grain Team authored and copybara-github committed
Internal
PiperOrigin-RevId: 835353176
1 parent 757c679 commit 1fff5a6

File tree

2 files changed: +29 −9 lines changed

grain/_src/python/dataset/dataset.py

Lines changed: 6 additions & 2 deletions
@@ -50,6 +50,7 @@
 from collections.abc import Awaitable, Callable, Iterable, Iterator, Mapping, Sequence
 import functools
 import json
+import threading
 from typing import Any, Generic, TypeVar, Union, cast, overload
 import warnings
 
@@ -67,6 +68,7 @@
 from grain._src.core import monitoring
 
 
+_usage_logging_lock = threading.Lock()
 _api_usage_counter = monitoring.Counter(
     "/grain/python/lazy_dataset/api",
     metadata=monitoring.Metadata(
@@ -358,7 +360,8 @@ def __init__(self, parents: MapDataset | Sequence[MapDataset] = ()):
     parents = tuple(parents)
     super().__init__(parents)
     self._parents = cast(Sequence[MapDataset], self._parents)
-    usage_logging.log_event("MapDataset", tag_3="PyGrain")
+    with _usage_logging_lock:
+      usage_logging.log_event("MapDataset", tag_3="PyGrain")
     _api_usage_counter.Increment("MapDataset")
 
   @property
@@ -977,7 +980,8 @@ def __init__(
     self._parents = cast(
         Sequence[Union[MapDataset, IterDataset]], self._parents
     )
-    usage_logging.log_event("IterDataset", tag_3="PyGrain")
+    with _usage_logging_lock:
+      usage_logging.log_event("IterDataset", tag_3="PyGrain")
     _api_usage_counter.Increment("IterDataset")
 
   @property
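The dataset.py change serializes calls to usage_logging.log_event behind a module-level lock, so datasets constructed concurrently (for example under free-threaded Python) cannot race inside the logging helper. Below is a minimal standalone sketch of the same pattern; record_event and log_dataset_created are hypothetical stand-ins, not Grain APIs.

import threading

# Hypothetical stand-in for grain._src.core.usage_logging, assumed not to be
# safe to call from multiple threads at once.
def record_event(name: str) -> None:
  print(f"usage event: {name}")

_usage_logging_lock = threading.Lock()

def log_dataset_created(name: str) -> None:
  # Serialize calls so concurrent constructors cannot overlap inside the
  # (assumed non-thread-safe) logging helper.
  with _usage_logging_lock:
    record_event(name)

# Constructing datasets from several threads now logs safely.
threads = [
    threading.Thread(target=log_dataset_created, args=("MapDataset",))
    for _ in range(4)
]
for t in threads:
  t.start()
for t in threads:
  t.join()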

grain/_src/python/shared_memory_array_test.py

Lines changed: 23 additions & 7 deletions
@@ -153,31 +153,47 @@ def test_del_many_async_reuse_pool(self):
     )
     original_close_shm_async = SharedMemoryArray.close_shm_async
 
+    # Use a semaphore to track completed async deletions.
+    completed_sem = threading.Semaphore(0)
+    # Use a thread-safe counter because mock.call_count is not thread-safe in
+    # free-threaded Python.
+    call_count = 0
+    count_lock = threading.Lock()
+
     def my_close_shm_async(shm, unlink_on_del):
       original_close_shm_async(shm, unlink_on_del)
+      with count_lock:
+        nonlocal call_count
+        call_count += 1
+      completed_sem.release()
 
     with mock.patch.object(
         SharedMemoryArray, "close_shm_async", side_effect=my_close_shm_async
-    ) as mock_close_shm_async:
+    ):
       with self.subTest("first_round_of_requests"):
         shm_metadatas = [
             _create_and_delete_shm() for _ in range(max_outstanding_requests)
         ]
         for metadata in shm_metadatas:
           _wait_for_deletion(metadata)
-        self.assertEqual(
-            max_outstanding_requests, mock_close_shm_async.call_count
-        )
+        # Wait for all async deletions to complete to ensure the semaphore in
+        # SharedMemoryArray is released.
+        for _ in range(max_outstanding_requests):
+          completed_sem.acquire()
+        with count_lock:
+          self.assertEqual(max_outstanding_requests, call_count)
+
       with self.subTest("second_round_of_requests"):
         # Do it again to make sure the pool is reused.
         shm_metadatas = [
            _create_and_delete_shm() for _ in range(max_outstanding_requests)
         ]
         for metadata in shm_metadatas:
           _wait_for_deletion(metadata)
-        self.assertEqual(
-            2 * max_outstanding_requests, mock_close_shm_async.call_count
-        )
+        for _ in range(max_outstanding_requests):
+          completed_sem.acquire()
+        with count_lock:
+          self.assertEqual(2 * max_outstanding_requests, call_count)
 
 
 if __name__ == "__main__":
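The test change replaces reads of mock_close_shm_async.call_count with a lock-protected counter plus a semaphore, since (per the new comment) mock.call_count is not thread-safe under free-threaded Python. Below is a small self-contained sketch of that counting pattern; the names (on_done, num_tasks, AsyncCompletionCountTest) are illustrative and not taken from the Grain test.

import threading
import unittest
from concurrent.futures import ThreadPoolExecutor

class AsyncCompletionCountTest(unittest.TestCase):

  def test_counts_async_work(self):
    num_tasks = 8
    # Completions are tracked with a lock-guarded counter and signalled via a
    # semaphore, instead of relying on a Mock's call_count attribute.
    completed_sem = threading.Semaphore(0)
    count_lock = threading.Lock()
    call_count = 0

    def on_done():
      nonlocal call_count
      with count_lock:
        call_count += 1
      completed_sem.release()

    with ThreadPoolExecutor(max_workers=4) as pool:
      for _ in range(num_tasks):
        pool.submit(on_done)
      # Block until every submitted task has signalled completion.
      for _ in range(num_tasks):
        completed_sem.acquire()

    with count_lock:
      self.assertEqual(num_tasks, call_count)


if __name__ == "__main__":
  unittest.main()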
