diff --git a/mmdet/utils/contextmanagers.py b/mmdet/utils/contextmanagers.py index 12073bef93219fedd96713aa7c6452c1530679a5..0363f0145af80221babc53e8869a0c0d519cc31e 100644 --- a/mmdet/utils/contextmanagers.py +++ b/mmdet/utils/contextmanagers.py @@ -14,8 +14,8 @@ DEBUG_COMPLETED_TIME = bool(os.environ.get('DEBUG_COMPLETED_TIME', False)) @contextlib.asynccontextmanager -async def completed(trace_name="", - name="", +async def completed(trace_name='', + name='', sleep_interval=0.05, streams: List[torch.cuda.Stream] = None): """ @@ -42,7 +42,7 @@ async def completed(trace_name="", stream_before_context_switch.record_event(start) cpu_start = time.monotonic() - logger.debug("%s %s starting, streams: %s", trace_name, name, streams) + logger.debug('%s %s starting, streams: %s', trace_name, name, streams) grad_enabled_before = torch.is_grad_enabled() try: yield @@ -60,37 +60,41 @@ async def completed(trace_name="", # observed change of torch.is_grad_enabled() during concurrent run of # async_test_bboxes code - assert grad_enabled_before == grad_enabled_after, \ - "Unexpected is_grad_enabled() value change" + assert (grad_enabled_before == grad_enabled_after + ), 'Unexpected is_grad_enabled() value change' are_done = [e.query() for e in end_events] - logger.debug("%s %s completed: %s streams: %s", trace_name, name, + logger.debug('%s %s completed: %s streams: %s', trace_name, name, are_done, streams) with torch.cuda.stream(stream_before_context_switch): while not all(are_done): await asyncio.sleep(sleep_interval) are_done = [e.query() for e in end_events] - logger.debug("%s %s completed: %s streams: %s", trace_name, - name, are_done, streams) + logger.debug( + '%s %s completed: %s streams: %s', + trace_name, + name, + are_done, + streams, + ) current_stream = torch.cuda.current_stream() assert current_stream == stream_before_context_switch if DEBUG_COMPLETED_TIME: cpu_time = (cpu_end - cpu_start) * 1000 - stream_times_ms = "" + stream_times_ms = '' for i, stream in enumerate(streams): elapsed_time = start.elapsed_time(end_events[i]) - stream_times_ms += " {stream} {elapsed_time:.2f} ms".format( - stream, elapsed_time) - logger.info("{trace_name} {name} cpu_time {cpu_time:.2f} ms", - trace_name, name, cpu_time, stream_times_ms) + stream_times_ms += ' {} {:.2f} ms'.format(stream, elapsed_time) + logger.info('%s %s %.2f ms %s', trace_name, name, cpu_time, + stream_times_ms) @contextlib.asynccontextmanager async def concurrent(streamqueue: asyncio.Queue, - trace_name="concurrent", - name="stream"): + trace_name='concurrent', + name='stream'): """Run code concurrently in different streams. :param streamqueue: asyncio.Queue instance. @@ -110,12 +114,12 @@ async def concurrent(streamqueue: asyncio.Queue, try: with torch.cuda.stream(stream): - logger.debug("%s %s is starting, stream: %s", trace_name, name, + logger.debug('%s %s is starting, stream: %s', trace_name, name, stream) yield current = torch.cuda.current_stream() assert current == stream - logger.debug("%s %s has finished, stream: %s", trace_name, + logger.debug('%s %s has finished, stream: %s', trace_name, name, stream) finally: streamqueue.task_done()