Skip to content
Snippets Groups Projects
Commit 4472c661 authored by Rinat Shigapov's avatar Rinat Shigapov Committed by Kai Chen
Browse files

fix debug completed logging (#1897)

parent 8df11f96
No related branches found
No related tags found
No related merge requests found
......@@ -14,8 +14,8 @@ DEBUG_COMPLETED_TIME = bool(os.environ.get('DEBUG_COMPLETED_TIME', False))
@contextlib.asynccontextmanager
async def completed(trace_name="",
name="",
async def completed(trace_name='',
name='',
sleep_interval=0.05,
streams: List[torch.cuda.Stream] = None):
"""
......@@ -42,7 +42,7 @@ async def completed(trace_name="",
stream_before_context_switch.record_event(start)
cpu_start = time.monotonic()
logger.debug("%s %s starting, streams: %s", trace_name, name, streams)
logger.debug('%s %s starting, streams: %s', trace_name, name, streams)
grad_enabled_before = torch.is_grad_enabled()
try:
yield
......@@ -60,37 +60,41 @@ async def completed(trace_name="",
# observed change of torch.is_grad_enabled() during concurrent run of
# async_test_bboxes code
assert grad_enabled_before == grad_enabled_after, \
"Unexpected is_grad_enabled() value change"
assert (grad_enabled_before == grad_enabled_after
), 'Unexpected is_grad_enabled() value change'
are_done = [e.query() for e in end_events]
logger.debug("%s %s completed: %s streams: %s", trace_name, name,
logger.debug('%s %s completed: %s streams: %s', trace_name, name,
are_done, streams)
with torch.cuda.stream(stream_before_context_switch):
while not all(are_done):
await asyncio.sleep(sleep_interval)
are_done = [e.query() for e in end_events]
logger.debug("%s %s completed: %s streams: %s", trace_name,
name, are_done, streams)
logger.debug(
'%s %s completed: %s streams: %s',
trace_name,
name,
are_done,
streams,
)
current_stream = torch.cuda.current_stream()
assert current_stream == stream_before_context_switch
if DEBUG_COMPLETED_TIME:
cpu_time = (cpu_end - cpu_start) * 1000
stream_times_ms = ""
stream_times_ms = ''
for i, stream in enumerate(streams):
elapsed_time = start.elapsed_time(end_events[i])
stream_times_ms += " {stream} {elapsed_time:.2f} ms".format(
stream, elapsed_time)
logger.info("{trace_name} {name} cpu_time {cpu_time:.2f} ms",
trace_name, name, cpu_time, stream_times_ms)
stream_times_ms += ' {} {:.2f} ms'.format(stream, elapsed_time)
logger.info('%s %s %.2f ms %s', trace_name, name, cpu_time,
stream_times_ms)
@contextlib.asynccontextmanager
async def concurrent(streamqueue: asyncio.Queue,
trace_name="concurrent",
name="stream"):
trace_name='concurrent',
name='stream'):
"""Run code concurrently in different streams.
:param streamqueue: asyncio.Queue instance.
......@@ -110,12 +114,12 @@ async def concurrent(streamqueue: asyncio.Queue,
try:
with torch.cuda.stream(stream):
logger.debug("%s %s is starting, stream: %s", trace_name, name,
logger.debug('%s %s is starting, stream: %s', trace_name, name,
stream)
yield
current = torch.cuda.current_stream()
assert current == stream
logger.debug("%s %s has finished, stream: %s", trace_name,
logger.debug('%s %s has finished, stream: %s', trace_name,
name, stream)
finally:
streamqueue.task_done()
......
0% Loading or load error.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment