From 7678c3684da5f5f6186cf46e733ba5e37a87fe54 Mon Sep 17 00:00:00 2001
From: amaurya
Date: Thu, 22 Feb 2024 15:58:10 +0000
Subject: [PATCH] Starting cleanup

Remove the unused deepspeed_py_veloc.cpp and the old AsyncCheckpointEngine,
allocate the pinned host cache with a single cudaMallocHost() call instead of
the background allocate_pin_mem() thread, and drop commented-out debug code
from the VELOC checkpoint engine and its wiring in runtime/engine.py.
---
 csrc/veloc/deepspeed_py_veloc.cpp  | 17 ----
 csrc/veloc/memory_cache.cpp        | 68 +-------------
 csrc/veloc/memory_cache.hpp        |  2 -
 .../async_checkpoint_engine_old.py | 94 -------------------
 .../veloc_checkpoint_engine.py     | 17 +---
 deepspeed/runtime/engine.py        |  5 -
 6 files changed, 4 insertions(+), 199 deletions(-)
 delete mode 100644 csrc/veloc/deepspeed_py_veloc.cpp
 delete mode 100644 deepspeed/runtime/checkpoint_engine/async_checkpoint_engine_old.py

diff --git a/csrc/veloc/deepspeed_py_veloc.cpp b/csrc/veloc/deepspeed_py_veloc.cpp
deleted file mode 100644
index d2a809333720..000000000000
--- a/csrc/veloc/deepspeed_py_veloc.cpp
+++ /dev/null
@@ -1,17 +0,0 @@
-#include
-
-void veloc_ckpt_t::ckpt(ckpt_map_t &m, std::string path) {
-    c10::impl::GenericDict generic_dict(c10::StringType::get(), m.begin()->second.type());
-    for (const auto& entry : m) {
-        generic_dict.insert(entry.first, entry.second);
-    }
-    torch::serialize::OutputArchive output_archive;
-    output_archive.write("data", generic_dict);
-    output_archive.save_to(path);
-    std::cout << "Saving complete from CPP size for " << path << std::endl;
-    return;
-}
-
-void veloc_ckpt_t::wait(size_t tensor_id) {
-    std::cout << "Not implemented yet" << std::endl;
-}
diff --git a/csrc/veloc/memory_cache.cpp b/csrc/veloc/memory_cache.cpp
index 04894ca4fe8f..d9566454e98d 100644
--- a/csrc/veloc/memory_cache.cpp
+++ b/csrc/veloc/memory_cache.cpp
@@ -3,8 +3,8 @@
 memory_cache_t::memory_cache_t(int d, size_t t, int r): _device_id(d), _total_memory(t), _curr_size(0), _head(0), _tail(0), _rank(r) {
     try {
         is_active = true;
+        checkCuda(cudaMallocHost(&_start_ptr, _total_memory));
         max_allocated = 0;
-        malloc_thread = std::thread([&] { allocate_pin_mem(); });
         DBG("Returned from the memory cache function..... ");
     } catch (std::exception &e) {
         FATAL("Exception caught in memory cache constructor." << e.what());
@@ -33,7 +33,6 @@ void memory_cache_t::shutdown() {
         _mem_lock.unlock();
         _mem_cv.notify_all();
         DBG("[" << _rank << "]" << "Memory cache shutdown starting");
-        malloc_thread.join();
         DBG("[" << _rank << "]" << "Memory cache shutdown complete");
     } catch (std::exception &e) {
         FATAL("Exception caught in memory cache destructor." << e.what());
@@ -42,67 +41,6 @@
     }
 }
 
-void memory_cache_t::allocate_pin_mem() {
-    try{
-        TIMER_START(alloc_start);
-
-        checkCuda(cudaSetDevice(_device_id));
-        checkCuda(cudaFree(0));
-        int posix_memalign_result = posix_memalign((void **)&_start_ptr, HUGEPAGES_SIZE, _total_memory);
-        madvise(_start_ptr, _total_memory, MADV_HUGEPAGE);
-        if (posix_memalign_result != 0) {
-            FATAL("Error allocating hugepages: " << posix_memalign_result);
-        }
-        std::unique_lock _mem_lock(_mem_mutex, std::defer_lock);
-        omp_set_num_threads(MALLOC_THREADS);
-        TIMER_START(alloc_starting);
-        size_t rem = _total_memory;
-        while (max_allocated != _total_memory && is_active) {
-            size_t chunk = MIN_CHUNK_SIZE < rem ? MIN_CHUNK_SIZE : rem;
-            #pragma omp parallel
-            {
-                char *buf = (char *)_start_ptr+max_allocated;
-                int id = omp_get_thread_num();
-                int num = omp_get_num_threads();
-                size_t my_size = chunk/num;
-                size_t my_start = id*my_size;
-                // if (touch_pages) {
-                //     #pragma omp parallel for
-                //     {
-                //         for(size_t i = 0; i < my_size; i+=HUGEPAGES_SIZE)
-                for(size_t i = 0; i < my_size; i+=CACHE_LINE_SIZE)
-                    buf[my_start + i] = 0x00;
-                //     }
-                // } else {
-                //     memset(buf+my_start, 0, my_size);
-                // }
-            }
-            _mem_lock.lock();
-            max_allocated += chunk;
-            _mem_lock.unlock();
-            _mem_cv.notify_all();
-            rem -= chunk;
-        }
-        if (!is_active)
-            return;
-        if (max_allocated != _total_memory)
-            FATAL("Max allocated is not same as total " << max_allocated << " total was " << _total_memory);
-        TIMER_STOP(alloc_starting, "Simple malloc and touch done", _total_memory);
-        TIMER_START(pinning);
-        _mem_lock.lock();
-        checkCuda(cudaHostRegister(_start_ptr, _total_memory, cudaHostRegisterPortable));
-        _mem_lock.unlock();
-        _mem_cv.notify_all();
-        TIMER_STOP(pinning, "Time to pin memory", _total_memory);
-        TIMER_STOP(alloc_start, "Host memory allocation time on device " << _rank, _total_memory);
-        return;
-    } catch (std::exception &e) {
-        FATAL("Exception caught in allocate pin memory." << e.what());
-    } catch (...) {
-        FATAL("Unknown exception caught in allocate pin memory.");
-    }
-}
-
 mem_region_t* memory_cache_t::_assign(const uint64_t uid, size_t h, size_t s) {
     try {
         if (h+s > _total_memory) {
@@ -134,9 +72,8 @@ mem_region_t* memory_cache_t::allocate(const uint64_t uid, size_t s) {
         FATAL("[" << _rank << "]" << "Cannot allocate size " << s << " larger than the pool of " << _total_memory);
     mem_region_t* ptr = nullptr;
     std::unique_lock _mem_lock(_mem_mutex);
-    while(((max_allocated < _total_memory) || (_curr_size + s > _total_memory)) && is_active) {
+    while((_curr_size + s > _total_memory) && is_active)
         _mem_cv.wait(_mem_lock);
-    }
     if (!is_active) {
         _mem_lock.unlock();
         _mem_cv.notify_all();
@@ -195,7 +132,6 @@ void memory_cache_t::deallocate(uint64_t _uid, size_t s) {
            std::cout << "Tried deleting " << _uid << " of size " << s << " at offset " << m->start_offset
                << " but front element was " << (void *)m->ptr << " of size " << m->end_offset-m->start_offset << std::endl;
            _print_trace();
-           // std::abort();
           return;
        }
        std::unique_lock _mem_lock(_mem_mutex);
diff --git a/csrc/veloc/memory_cache.hpp b/csrc/veloc/memory_cache.hpp
index 35ef169203f4..20cbb70f2c6a 100644
--- a/csrc/veloc/memory_cache.hpp
+++ b/csrc/veloc/memory_cache.hpp
@@ -77,14 +77,12 @@ class memory_cache_t {
     std::mutex _mem_mutex;
     std::condition_variable _mem_cv;
     std::deque _mem_q;
-    std::thread malloc_thread;
     size_t max_allocated = 0;
     bool is_active = true;
     int _rank = -1;
 public:
     memory_cache_t(int d, size_t t, int rank);
     ~memory_cache_t();
-    void allocate_pin_mem();
     mem_region_t* _assign(const uint64_t uid, size_t h, size_t s);
     mem_region_t* allocate(const uint64_t uid, size_t s);
     void deallocate(uint64_t _uid, size_t s);
diff --git a/deepspeed/runtime/checkpoint_engine/async_checkpoint_engine_old.py b/deepspeed/runtime/checkpoint_engine/async_checkpoint_engine_old.py
deleted file mode 100644
index f7232349714f..000000000000
--- a/deepspeed/runtime/checkpoint_engine/async_checkpoint_engine_old.py
+++ /dev/null
@@ -1,94 +0,0 @@
-# Copyright (c) Microsoft Corporation.
-# SPDX-License-Identifier: Apache-2.0
-
-# DeepSpeed Team
-
-import torch
-from deepspeed.utils import logger, log_dist, instrument_w_nvtx
-from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \
-    CheckpointEngine
-import time
-from threading import Thread, Condition, Lock
-from concurrent.futures import ThreadPoolExecutor
-import copy
-
-class AsyncCheckpointEngine(CheckpointEngine):
-
-    def __init__(self, config_params, r):
-        super().__init__(config_params, r)
-        self.checkpoint_in_progress = False
-        self.rank = r
-        in_progress_lock = Lock()
-        self.in_progress_cv = Condition(lock=in_progress_lock)
-        self.futures = None
-        self.executor = ThreadPoolExecutor(max_workers=1)
-
-    def create(self, tag):
-        log_dist(f"[AsyncTorch] Checkpoint {tag} is about to be saved!", ranks=[0])
-
-    @instrument_w_nvtx
-    def _to_cpu(self, ele, snapshot):
-        try:
-            if torch.is_tensor(ele) and ele.device.type=='cuda':
-                snapshot = ele.cpu()
-            # elif isinstance(ele, dict) and not isinstance(ele, OrderedDict):
-            elif isinstance(ele, dict):
-                snapshot = {}
-                for (k, v) in ele.items():
-                    snapshot[k] = None
-                    snapshot[k] = self._to_cpu(v, snapshot[k])
-            elif isinstance(ele, list):
-                snapshot = [None for _ in range(len(ele))]
-                for (idx, v) in enumerate(ele):
-                    snapshot[idx] = self._to_cpu(v, snapshot[idx])
-            else:
-                log_dist(f"[AsyncTorch] Got in parse dict of type {type(ele)}: {ele}")
-                snapshot = copy.deepcopy(ele)
-            return snapshot
-        except Exception as exc:
-            logger.info(f"[AsyncTorch] From _to_cpu, generated exception: {exc}")
-
-    def _background_save(self, state_dict, path):
-        t = time.time()
-        with self.in_progress_cv:
-            self.checkpoint_in_progress = True
-            self.in_progress_cv.notify_all()
-        torch.save(state_dict, path)
-        with self.in_progress_cv:
-            self.checkpoint_in_progress = False
-            self.in_progress_cv.notify_all()
-        logger.info(f"[AsyncTorch] Time to complete background save {time.time()-t} for path {path}")
-
-    @instrument_w_nvtx
-    def save(self, state_dict, path: str):
-        logger.info(f"[AsyncTorch] Saving {path}...")
-        t = time.time()
-        while self.checkpoint_in_progress:
-            with self.in_progress_cv:
-                self.in_progress_cv.wait()
-        logger.info(f"[AsyncTorch] Prev completion waiting time {time.time()-t} for incoming {path}...")
-        new_state_dict = {}
-        new_state_dict = self._to_cpu(state_dict, new_state_dict)
-        logger.info(f"[AsyncTorch] To CPU snapshot time {time.time()-t} for incoming {path}...")
-        # torch.save(state_dict, path)
-        ts = time.time()
-        self.executor.submit(self._background_save, new_state_dict, path)
-        logger.info(f"[AsyncTorch] Time to submit to background save {time.time()-ts}")
-        logger.info(f"[AsyncTorch] Saved {path}. in time {time.time()-t}")
-        return None
-
-    def load(self, path: str, map_location=None):
-        logger.info(f"[AsyncTorch] Loading checkpoint from {path}...")
-        partition = torch.load(path, map_location=map_location)
-        logger.info(f"[AsyncTorch] Loaded checkpoint from {path}.")
-        return partition
-
-    def commit(self, tag):
-        logger.info(f"[AsyncTorch] Checkpoint {tag} is ready now!")
-        return True
-
-    def wait(self, prev_version = -1):
-        return True
-
-    def shutdown(self):
-        return True
\ No newline at end of file
diff --git a/deepspeed/runtime/checkpoint_engine/veloc_checkpoint_engine.py b/deepspeed/runtime/checkpoint_engine/veloc_checkpoint_engine.py
index 3a9b0a2cf9f5..42f6f269b3cd 100644
--- a/deepspeed/runtime/checkpoint_engine/veloc_checkpoint_engine.py
+++ b/deepspeed/runtime/checkpoint_engine/veloc_checkpoint_engine.py
@@ -43,7 +43,6 @@ class VELOCCheckpointEngine(CheckpointEngine):
 
     def __init__(self, config_params, r):
         try:
-            # t = time.time()
             super().__init__(config_params, r)
             self.rank = r
             self.ckpt_engine = VelocCkptBuilder().load().veloc_ckpt_handle(
@@ -54,7 +53,6 @@ def __init__(self, config_params, r):
             )
             self.futures = deque()
             self.executor = ThreadPoolExecutor(max_workers=int(config_params["writer_threads"]))
-            # print(f"[VELOC] Init took {time.time()-t}")
         except Exception as exc2:
             print("[ERROR]Got exception during VELOC init ", exc2)
             sys.exit(-1)
@@ -66,12 +64,9 @@ def create(self, tag):
     # @instrument_w_nvtx
     def _parse_dict(self, ele, snapshot, async_copies_list):
         try:
-            if isinstance(ele, np.ndarray): # and ele.nbytes > ASYNC_CKPT_SIZE_MIN:
-                print("Got a numpy array")
-                # import pdb; pdb.set_trace();
+            if isinstance(ele, np.ndarray):
                 data_device = -1
                 snapshot = f"{len(async_copies_list)}-pickled-numpy"
-                # Storing in async_copies_list values: data_ptr, size_in_bytes, device_id, file_offset
                 async_copies_list.append([ele.ctypes.data, ele.nbytes, -1, 0])
             elif torch.is_tensor(ele) and ele.device.type == 'cuda':
                 if (ele.numel()*ele.element_size() > ASYNC_CKPT_SIZE_MIN):
@@ -91,7 +86,7 @@ def _parse_dict(self, ele, snapshot, async_copies_list):
                     snapshot[idx], async_copies_list = self._parse_dict(v, snapshot[idx], async_copies_list)
             else:
                 log_dist(f"[VELOC] Got in parse dict of type {type(ele)}: {ele}")
-                snapshot = ele # copy.deepcopy(ele)
+                snapshot = ele
             return snapshot, async_copies_list
         except Exception as exc:
             logger.info(f"[VELOC][ERROR] From _to_cpu, generated exception: {exc}")
@@ -132,9 +127,6 @@ def save_background(self, state_dict, path: str):
                 file.write(str(len(headers)).encode("utf-8"))
                 file.write(headers)
                 file.write(serialized_dict)
-
-            # logger.info(f"[VELOC] In background meta-data thread saved {path} in time {time.time()-start_time}")
-            # sys.stdout = redirect._stdout
             return None
         except Exception as exc:
             logger.info(f"[VELOC][ERROR] From VELOC save_background, generated exception: {exc}")
@@ -142,16 +134,12 @@ def save_background(self, state_dict, path: str):
 
     def save(self, state_dict, path: str):
         try:
-            # start_time = time.time()
             f = self.executor.submit(self.save_background, state_dict, path)
             self.futures.append(f)
-            # logger.info(f"[VELOC] Saved {path}. in time {time.time()-start_time}")
             return True
         except Exception as exc:
             logger.info(f"[VELOC][ERROR] From save, generated exception: {exc}")
             sys.exit(-1)
-
-
     def load(self, path: str, map_location=None):
         logger.info(f"[VELOC] Loading checkpoint from {path}...")
         partition =
@@ -160,7 +148,6 @@ def load(self, path: str, map_location=None):
         return partition
 
     def commit(self, tag):
-        # self.ckpt_engine.wait(-1)
         logger.info(f"[VELOC] Checkpoint {tag} is ready now!")
         return True
 
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 8e134276c6e9..ad84cc647fd2 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -957,11 +957,6 @@ def _configure_checkpointing(self, dist_init_required):
             from deepspeed.runtime.checkpoint_engine.torch_sn_async_checkpoint_engine import TSNAsyncCheckpointEngine
             self.checkpoint_engine = TSNAsyncCheckpointEngine(self._config.torch_sn_async_ckpt_config, self.global_rank)
-
-        if self._config is not None and self._config.veloc_config:
-            from deepspeed.runtime.checkpoint_engine.veloc_checkpoint_engine import \
-                VELOCCheckpointEngine
-            self.checkpoint_engine = VELOCCheckpointEngine({"rank": rank, "local_rank": self.local_rank, "device": torch.cuda.current_device(), "aio_config": self._config.aio_config})
 
         # only the first data parallel process needs to store the model checkpoint
         # if you want to use node local storage this must be done by rank 0 on each
         # node