Commit 4472c661 authored by Rinat Shigapov, committed by Kai Chen

fix debug completed logging (#1897)

parent 8df11f96
@@ -14,8 +14,8 @@ DEBUG_COMPLETED_TIME = bool(os.environ.get('DEBUG_COMPLETED_TIME', False))
 
 
 @contextlib.asynccontextmanager
-async def completed(trace_name="",
-                    name="",
+async def completed(trace_name='',
+                    name='',
                     sleep_interval=0.05,
                     streams: List[torch.cuda.Stream] = None):
     """
@@ -42,7 +42,7 @@ async def completed(trace_name="",
     stream_before_context_switch.record_event(start)
 
     cpu_start = time.monotonic()
-    logger.debug("%s %s starting, streams: %s", trace_name, name, streams)
+    logger.debug('%s %s starting, streams: %s', trace_name, name, streams)
     grad_enabled_before = torch.is_grad_enabled()
     try:
         yield
@@ -60,37 +60,41 @@ async def completed(trace_name="",
         # observed change of torch.is_grad_enabled() during concurrent run of
         # async_test_bboxes code
-        assert grad_enabled_before == grad_enabled_after, \
-            "Unexpected is_grad_enabled() value change"
+        assert (grad_enabled_before == grad_enabled_after
+                ), 'Unexpected is_grad_enabled() value change'
 
         are_done = [e.query() for e in end_events]
-        logger.debug("%s %s completed: %s streams: %s", trace_name, name,
+        logger.debug('%s %s completed: %s streams: %s', trace_name, name,
                      are_done, streams)
         with torch.cuda.stream(stream_before_context_switch):
             while not all(are_done):
                 await asyncio.sleep(sleep_interval)
                 are_done = [e.query() for e in end_events]
-                logger.debug("%s %s completed: %s streams: %s", trace_name,
-                             name, are_done, streams)
+                logger.debug(
+                    '%s %s completed: %s streams: %s',
+                    trace_name,
+                    name,
+                    are_done,
+                    streams,
+                )
 
         current_stream = torch.cuda.current_stream()
         assert current_stream == stream_before_context_switch
 
         if DEBUG_COMPLETED_TIME:
             cpu_time = (cpu_end - cpu_start) * 1000
-            stream_times_ms = ""
+            stream_times_ms = ''
             for i, stream in enumerate(streams):
                 elapsed_time = start.elapsed_time(end_events[i])
-                stream_times_ms += " {stream} {elapsed_time:.2f} ms".format(
-                    stream, elapsed_time)
-            logger.info("{trace_name} {name} cpu_time {cpu_time:.2f} ms",
-                        trace_name, name, cpu_time, stream_times_ms)
+                stream_times_ms += ' {} {:.2f} ms'.format(stream, elapsed_time)
+            logger.info('%s %s %.2f ms %s', trace_name, name, cpu_time,
+                        stream_times_ms)
 
 
 @contextlib.asynccontextmanager
 async def concurrent(streamqueue: asyncio.Queue,
-                     trace_name="concurrent",
-                     name="stream"):
+                     trace_name='concurrent',
+                     name='stream'):
     """Run code concurrently in different streams.
 
     :param streamqueue: asyncio.Queue instance.
 
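The hunk above is the substance of the fix. The old logger.info call used str.format-style {} placeholders while passing the %-style lazy arguments that the stdlib logging module expects, so the timing line never interpolated (logging reports a formatting error on stderr instead of emitting the message). The old " {stream} {elapsed_time:.2f} ms".format(stream, elapsed_time) used named fields with positional arguments, which raises KeyError: 'stream' as soon as DEBUG_COMPLETED_TIME is enabled. A minimal standalone repro of both failure modes, illustrative only and not part of the commit:

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    stream, elapsed_time = 'stream0', 1.234

    # Named {} fields with positional arguments fail at call time.
    try:
        ' {stream} {elapsed_time:.2f} ms'.format(stream, elapsed_time)
    except KeyError as exc:
        print('old .format() call raises KeyError:', exc)

    # A {}-style template with %-style logging args never interpolates:
    # logging applies 'msg % args' lazily, that fails here, and the handler
    # prints "--- Logging error ---" to stderr instead of the message.
    logger.info('{trace_name} {name} cpu_time {cpu_time:.2f} ms',
                'demo', 'run', 12.3, '')

    # The fixed forms: positional .format() fields and %-style placeholders.
    print(' {} {:.2f} ms'.format(stream, elapsed_time))
    logger.info('%s %s %.2f ms %s', 'demo', 'run', 12.3, ' stream0 1.23 ms')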
@@ -110,12 +114,12 @@ async def concurrent(streamqueue: asyncio.Queue,
     try:
         with torch.cuda.stream(stream):
-            logger.debug("%s %s is starting, stream: %s", trace_name, name,
+            logger.debug('%s %s is starting, stream: %s', trace_name, name,
                          stream)
             yield
             current = torch.cuda.current_stream()
             assert current == stream
-            logger.debug("%s %s has finished, stream: %s", trace_name,
+            logger.debug('%s %s has finished, stream: %s', trace_name,
                          name, stream)
     finally:
         streamqueue.task_done()
 
...
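For orientation, the two async context managers touched here compose: concurrent checks a CUDA stream out of an asyncio.Queue and makes it current for the block, while completed records start/end CUDA events on the given streams and polls event.query() via asyncio.sleep until the queued work finishes, without blocking the event loop. A rough usage sketch under stated assumptions: it needs a CUDA-capable machine, it assumes the import path mmdet.utils.contextmanagers (the filename is not shown on this page), and the matrix multiply is a hypothetical stand-in for real model code:

    import asyncio

    import torch

    # Assumption: these are the two context managers defined in this file.
    from mmdet.utils.contextmanagers import completed, concurrent

    async def do_work(streamqueue: asyncio.Queue):
        # Borrow a stream from the pool, run a kernel on it, then await
        # completion of the recorded CUDA events before leaving the block.
        async with concurrent(streamqueue):
            async with completed('demo', 'matmul',
                                 streams=[torch.cuda.current_stream()]):
                x = torch.randn(1024, 1024, device='cuda')
                x = x @ x

    async def main():
        streamqueue = asyncio.Queue()
        for _ in range(2):  # a pool of two streams shared by four tasks
            streamqueue.put_nowait(torch.cuda.Stream())
        await asyncio.gather(*(do_work(streamqueue) for _ in range(4)))

    asyncio.run(main())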