From 5b2c316890b7b8af95f0d0334d1fd34b9a687b99 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 27 Nov 2022 13:08:54 +0300 Subject: [PATCH] eliminate duplicated code from #5095 --- modules/devices.py | 30 +++++++++++------------------- 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 93d82bb..dd50fe2 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -24,17 +24,18 @@ def extract_device_id(args, name): return None +def get_cuda_device_string(): + from modules import shared + + if shared.cmd_opts.device_id is not None: + return f"cuda:{shared.cmd_opts.device_id}" + + return "cuda" + + def get_optimal_device(): if torch.cuda.is_available(): - from modules import shared - - device_id = shared.cmd_opts.device_id - - if device_id is not None: - cuda_device = f"cuda:{device_id}" - return torch.device(cuda_device) - else: - return torch.device("cuda") + return torch.device(get_cuda_device_string()) if has_mps(): return torch.device("mps") @@ -44,16 +45,7 @@ def get_optimal_device(): def torch_gc(): if torch.cuda.is_available(): - from modules import shared - - device_id = shared.cmd_opts.device_id - - if device_id is not None: - cuda_device = f"cuda:{device_id}" - else: - cuda_device = "cuda" - - with torch.cuda.device(cuda_device): + with torch.cuda.device(get_cuda_device_string()): torch.cuda.empty_cache() torch.cuda.ipc_collect()