from subprocess import check_output
from threading import Timer
from typing import Callable, List, Optional, Tuple


def get_gpu_memory() -> List[Tuple[int, int]]:
    """
    Get the used and total GPU memory (VRAM) in MiB for each GPU

    :return memory_values: List of (used, total) GPU memory (VRAM) tuples in MiB, one per GPU
    """
    command = "nvidia-smi --query-gpu=memory.used,memory.total --format=csv,noheader,nounits"
    # Each output line has the form "used, total"; drop the trailing empty line
    memory_info = check_output(command.split()).decode("ascii").replace("\r", "").split("\n")[:-1]
    memory_values = [tuple(map(int, line.split(","))) for line in memory_info]
    return memory_values


class RepeatingTimer(Timer):
    """Timer that keeps calling its function every `interval` seconds until cancelled."""

    def run(self):
        self.finished.wait(self.interval)
        while not self.finished.is_set():
            self.function(*self.args, **self.kwargs)
            self.finished.wait(self.interval)


gpu_memory_watcher: Optional[RepeatingTimer] = None


def watch_gpu_memory(interval: int = 1,
                     callback: Optional[Callable[[List[Tuple[int, int]]], None]] = None) -> RepeatingTimer:
    """
    Start a repeating timer to watch the GPU memory usage

    :param interval: Interval in seconds between measurements
    :param callback: Function called with the list of (used, total) memory tuples; defaults to print
    :return timer: RepeatingTimer object
    """
    global gpu_memory_watcher
    if gpu_memory_watcher is not None:
        raise RuntimeError("GPU memory watcher is already running")
    if callback is None:
        callback = print
    gpu_memory_watcher = RepeatingTimer(interval, lambda: callback(get_gpu_memory()))
    gpu_memory_watcher.start()
    return gpu_memory_watcher


def stop_watcher():
    """Cancel the running watcher (if any) and reset the global handle."""
    global gpu_memory_watcher
    if gpu_memory_watcher is None:
        return
    gpu_memory_watcher.cancel()
    gpu_memory_watcher = None


if __name__ == "__main__":
    from time import sleep

    t = watch_gpu_memory()
    counter = 0
    while True:
        sleep(1)
        counter += 1
        if counter == 10:
            # Starting a second watcher while one is running should raise
            try:
                watch_gpu_memory()
            except RuntimeError:
                print("Got exception")
        elif counter >= 20:
            stop_watcher()
            break