代码拉取完成,页面将自动刷新
//
// Created by hzy on 2024/2/6.
//
#include <algorithm>
#include <bitset>
#include <deque>
#include <map>
#include <memory>
#include <mutex>
#include <regex>
#include <set>
#include <vector>
#include <c10/util/irange.h>
#include "PluggableAllocator.h"
void local_raw_delete(void *ptr) {
PluggableAllocator::getInstance().free(ptr);
}
void PluggableAllocator::add_allocated_block(Block *block) {
std::lock_guard<std::mutex> lock(mutex);
allocated_blocks[block->ptr] = block;
}
std::mutex *PluggableAllocator::getFreeMutex() const {
return &npu_free_mutex;
}
Block *PluggableAllocator::get_allocated_block(void *ptr, bool remove) {
std::lock_guard<std::mutex> lock(mutex);
auto it = allocated_blocks.find(ptr);
if (it == allocated_blocks.end()) {
return nullptr;
}
Block *block = it->second;
if (remove) {
allocated_blocks.erase(it);
}
return block;
}
void PluggableAllocator::init(int device_count) {
int size = static_cast<int>(device_allocator.size());
if (size < device_count) {
device_allocator.resize(device_count);
for (const auto i: c10::irange(size, device_count)) {
device_allocator[i] = std::make_unique<DeviceCachingAllocator>();
}
}
}
bool PluggableAllocator::initialized() {
return !device_allocator.empty();
}
/** allocates a block which is safe to use from the provided stream */
void *PluggableAllocator::malloc(int device, size_t size, aclrtStream stream) {
Block *block = device_allocator[device]->malloc(device, size, stream);
add_allocated_block(block);
void *devPtr = static_cast<void *>(block->ptr);
return devPtr;
}
void PluggableAllocator::free(void *ptr) {
if (!ptr) {
return;
}
Block *block = get_allocated_block(ptr, true);
if (!block) {
AT_ERROR("invalid device pointer: ", ptr);
}
device_allocator[block->device]->free(block);
}
void PluggableAllocator::setMemoryFraction(double fraction, int device) {
TORCH_INTERNAL_ASSERT(
0 <= device && device < device_allocator.size(),
"Allocator not initialized for device ",
device,
": did you call init?");
TORCH_INTERNAL_ASSERT(
0 <= fraction && fraction <= 1,
"invalid fraction:",
fraction,
". Please set within (0, 1).");
c10_npu::SetDevice(device);
device_allocator[device]->setMemoryFraction(fraction);
}
void PluggableAllocator::emptyCache(bool check_error) {
int count = static_cast<int>(device_allocator.size());
for (int i = 0; i < count; i++)
device_allocator[i]->emptyCache(check_error);
}
void PluggableAllocator::setShutdownStats() {
int count = static_cast<int>(device_allocator.size());
for (int i = 0; i < count; i++) {
device_allocator[i]->devSetShutdownStats();
}
}
void *PluggableAllocator::getBaseAllocation(void *ptr, size_t *outSize) {
Block *block = get_allocated_block(ptr);
if (!block) {
AT_ERROR("invalid device pointer: ", ptr);
}
return device_allocator[block->device]->getBaseAllocation(block, outSize);
}
void PluggableAllocator::recordStream(const c10::DataPtr &ptr, c10_npu::NPUStream stream) {
// Empty tensor's storage().data() might be a null ptr. As there is no
// blocks associated with those tensors, it is fine to do nothing here.
if (!ptr.get()) {
return;
}
// If a tensor is not allocated by this instance, simply skip
// This usually happens when NPU tensors are shared across processes,
// we have implemented reference counting based sharing mechanism to
// guarantee tensors won't be accidentally freed by one process while
// they are still being used in another
if (ptr.get_deleter() != &local_raw_delete) {
return;
}
Block *block = get_allocated_block(ptr.get());
// block must not be null reaching here
TORCH_INTERNAL_ASSERT(block != nullptr, "No allocated block can be found");
device_allocator[block->device]->recordStream(block, stream);
}
void PluggableAllocator::eraseStream(const c10::DataPtr &ptr, c10_npu::NPUStream stream) {
if (!ptr.get()) {
return;
}
// If a tensor is not allocated by this instance, simply skip
// This usually happens when NPU tensors are shared across processes,
// we have implemented reference counting based sharing mechanism to
// guarantee tensors won't be accidentally freed by one process while
// they are still being used in another
if (ptr.get_deleter() != &local_raw_delete) {
// TORCH_NPU_WARN_ONCE("Tensor not is not allocated by NPUCachingAllocator, skip eraseStream.");
return;
}
Block *block = get_allocated_block(ptr.get());
if (!block) {
AT_ERROR("invalid device pointer: ", ptr.get());
}
if (block->stream != c10_npu::getCurrentNPUStream(block->device).stream(false)) {
// If the Stream applying for tensor block different from
// the stream of submiting event wait task in HCCL synchronize()
// method, the recordSteam can not be erased.
// New tensor creation may use the block before HCCL op is complete.
return;
}
device_allocator[block->device]->eraseStream(block, stream);
}
std::vector<SegmentInfo> PluggableAllocator::snapshot() {
std::vector<SegmentInfo> result;
int count = static_cast<int>(device_allocator.size());
for (int i = 0; i < count; i++) {
auto snap = device_allocator[i]->snapshot();
result.insert(result.end(), snap.begin(), snap.end());
}
return result;
}
c10::DeleterFnPtr PluggableAllocator::raw_deleter() const {
return &local_raw_delete;
}
void PluggableAllocator::cacheInfo(int dev_id, size_t *cachedAndFree, size_t *largestBlock) {
device_allocator[dev_id]->cacheInfo(cachedAndFree, largestBlock);
}
void PluggableAllocator::assertValidDevice(int device) {
int device_num = c10_npu::device_count();
AT_ASSERTM(0 <= device && device < device_num, "Invalid device argument.");
}
DeviceStats PluggableAllocator::getDeviceStats(int device) {
assertValidDevice(device);
return device_allocator[device]->getStats();
}
void PluggableAllocator::resetAccumulatedStats(int device) {
assertValidDevice(device);
device_allocator[device]->resetAccumulatedStats();
}
void PluggableAllocator::resetPeakStats(int device) {
assertValidDevice(device);
device_allocator[device]->resetPeakStats();
}
void PluggableAllocator::raw_delete(void *ptr) {
this->free(ptr);
}
void PluggableAllocator::FreeDeviceCachedMemory(int device) {
device_allocator[device]->emptyCache(true);
}
std::string PluggableAllocator::name() {
return "native";
}
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。