| """ | |
| Helpers for distributed training. | |
| """ | |
| import os | |
| import socket | |
| import torch as th | |
| import torch.distributed as dist | |
| from torch.distributed import barrier, is_initialized, broadcast | |
| # Change this to reflect your cluster layout. | |
| # The GPU for a given rank is (rank % GPUS_PER_NODE). | |
| GPUS_PER_NODE = 8 | |
| SETUP_RETRY_COUNT = 3 | |
| import datetime | |
| import os | |
| import socket | |
| from contextlib import closing | |


def find_free_port() -> int:
    """Ask the OS for an unused TCP port and return its number."""
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        # SO_REUSEADDR must be set before bind() to have any effect.
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(("", 0))
        return s.getsockname()[1]


def check_if_port_open(port: int) -> bool:
    """Return True if ``port`` can currently be bound on this host."""
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        try:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            s.bind(("", port))
            return True
        except OSError:
            return False


def initialized():
    return dist.is_initialized()


def finalize():
    if dist.is_initialized():
        dist.destroy_process_group()


def initialize():
    # torchrun / torch.distributed.launch export RANK, WORLD_SIZE, MASTER_ADDR
    # and MASTER_PORT; if they are missing, assume we were launched via mpirun.
    is_mpirun = not (
        "RANK" in os.environ
        and "WORLD_SIZE" in os.environ
        and "MASTER_ADDR" in os.environ
        and "MASTER_PORT" in os.environ
    )
    if is_mpirun:
        from mpi4py import MPI

        comm = MPI.COMM_WORLD
        rank = comm.Get_rank()
        world_size = comm.Get_size()

        master_addr = None
        master_port = None
        if rank == 0:
            result = subprocess.check_output("hostname -I", shell=True)
            master_addr = result.decode("utf-8").split()[0]
            base_port = os.environ.get(
                "MASTER_PORT", "29500"
            )  # TORCH_DISTRIBUTED_DEFAULT_PORT
            if check_if_port_open(int(base_port)):
                master_port = base_port
            else:
                master_port = find_free_port()
        master_addr = comm.bcast(master_addr, root=0)
        master_port = comm.bcast(master_port, root=0)

        # Determine the local rank, assuming hostnames are unique per node.
        proc_name = MPI.Get_processor_name()
        all_procs = comm.allgather(proc_name)
        local_rank = sum(p == proc_name for p in all_procs[:rank])
        uniq_proc_names = set(all_procs)
        host_rank = sorted(uniq_proc_names).index(proc_name)

        os.environ["LOCAL_RANK"] = str(local_rank)
        os.environ["HOST_RANK"] = str(host_rank)
        os.environ["NUM_HOSTS"] = str(len(uniq_proc_names))
        os.environ["RANK"] = str(rank)
        os.environ["WORLD_SIZE"] = str(world_size)
        os.environ["MASTER_ADDR"] = master_addr
        os.environ["MASTER_PORT"] = str(master_port)
        os.environ["OMP_NUM_THREADS"] = "1"

    # Initialize torch.distributed: NCCL on GPU machines, Gloo otherwise.
    backend = "gloo" if not th.cuda.is_available() else "nccl"
    dist.init_process_group(backend=backend, timeout=datetime.timedelta(0, 3600))
    if th.cuda.is_available():
        th.cuda.set_device(int(os.environ.get("LOCAL_RANK", "0")))

    if is_mpirun and dist.get_rank() == 0:
        print("Distributed setup")
        print("LOCAL_RANK", os.environ["LOCAL_RANK"])
        print("HOST_RANK", os.environ["HOST_RANK"])
        print("NUM_HOSTS", os.environ["NUM_HOSTS"])
        print("WORLD_SIZE", os.environ["WORLD_SIZE"])


def local_host_gather(data):
    """Gather ``data`` from every rank running on the same host as this rank."""
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    host_rank = os.environ["HOST_RANK"]
    all_data = comm.allgather((host_rank, data))
    return [d[1] for d in all_data if d[0] == host_rank]


def in_distributed_mode():
    # ``dist`` is always importable here, so ``dist is not None`` would always
    # be True; check for an initialized process group instead.
    return dist.is_available() and dist.is_initialized()


def is_master():
    return get_rank() == 0


def is_local_master():
    return get_local_rank() == 0


def get_rank():
    return dist.get_rank() if in_distributed_mode() else 0


def get_local_rank():
    # Default to 0 for single-process runs, matching initialize().
    return int(os.environ.get("LOCAL_RANK", "0"))


def worker_host_idx():
    return int(os.environ["HOST_RANK"])


def num_hosts():
    return int(os.environ["NUM_HOSTS"])


def get_world_size():
    return dist.get_world_size() if in_distributed_mode() else 1


def gpu_visible_device_list():
    return str(dist.get_rank()) if in_distributed_mode() else None


def get_device():
    """
    Get the device to use for torch.distributed.
    """
    if th.cuda.is_available():
        return th.device("cuda")
    return th.device("cpu")


def sync_params(params):
    """
    Synchronize a sequence of Tensors across ranks from rank 0.
    """
    for p in params:
        with th.no_grad():
            dist.broadcast(p, 0)
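
# Typical use (illustrative): broadcast rank 0's freshly constructed weights so
# every rank starts from identical parameters, e.g.
#   model = MyModel().to(get_device())   # MyModel is a placeholder
#   sync_params(model.parameters())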


def print0(*args, **kwargs):
    if get_rank() == 0:
        print(*args, **kwargs)


def allreduce(t: th.Tensor, async_op=False):
    if dist.is_initialized():
        if th.cuda.is_available() and not t.is_cuda:
            # NCCL only reduces CUDA tensors, so round-trip through the GPU
            # and copy the result back into the original CPU tensor.
            cu = t.detach().cuda()
            ret = dist.all_reduce(cu, async_op=async_op)
            t.copy_(cu.cpu())
        else:
            ret = dist.all_reduce(t, async_op=async_op)
        return ret
    return None


def allgather(t: th.Tensor, cat=True):
    if dist.is_initialized():
        if th.cuda.is_available() and not t.is_cuda:
            # NCCL only gathers CUDA tensors.
            t = t.cuda()
        ls = [th.empty_like(t) for _ in range(get_world_size())]
        dist.all_gather(ls, t)
    else:
        ls = [t]
    if cat:
        ls = th.cat(ls, dim=0)
    return ls
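

# ---------------------------------------------------------------------------
# Illustrative smoke test: running this file directly initializes the process
# group, broadcasts rank 0's tensor, and all-reduces/gathers per-rank values.
# The file name `dist_util.py` in the example launch command is an assumption:
#   torchrun --nproc_per_node=2 dist_util.py
if __name__ == "__main__":
    initialize()
    device = get_device()

    # Every rank builds a different tensor; sync_params overwrites all of them
    # with rank 0's values.
    params = [th.full((2,), float(get_rank()), device=device)]
    sync_params(params)
    print0("after sync_params:", params[0].tolist())  # expect [0.0, 0.0]

    # Sum each rank's id across the world and report the mean on rank 0.
    x = th.tensor([float(get_rank())], device=device)
    allreduce(x)
    print0("mean rank id:", (x / get_world_size()).tolist())

    # Gather one element per rank into a single tensor of length world_size.
    y = th.tensor([float(get_rank())], device=device)
    print0("gathered ranks:", allgather(y).tolist())

    finalize()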