#pragma once
#include <cstdint>
#include <cstring>
#include <vector>
#include "ext/xxhash.h"
#include "Common/CommonFuncs.h"
#include "Common/Log.h"
// Hashes the raw object representation of `k` (all sizeof(K) bytes) with XXH3
// and keeps only the low 32 bits. Because every byte participates, keys with
// padding must have that padding zero-initialized or equal keys may hash differently.
template<class K>
inline uint32_t HashKey(const K &k) {
	const auto wide = XXH3_64bits(&k, sizeof(k));
	return static_cast<uint32_t>(wide);  // truncation == masking with 0xFFFFFFFF
}
// Bytewise key equality: true when all sizeof(K) bytes of `a` and `b` match.
// Padding bytes are compared too, so padded key structs should be zero-filled.
template<class K>
inline bool KeyEquals(const K &a, const K &b) {
	const int diff = std::memcmp(&a, &b, sizeof(K));
	return diff == 0;
}
// Per-bucket occupancy marker used by the hash maps below.
// FREE must stay 0: the maps clear their state arrays with memset(..., FREE, ...)
// and rely on value-initialized vectors starting out FREE.
enum class BucketState : uint8_t {
	FREE = 0,     // never used since the last rebuild/clear; terminates probe chains
	TAKEN = 1,    // holds a live entry
	REMOVED = 2,  // tombstone: entry deleted, but probe chains continue past it
};
// Open-addressing hash map with linear probing.
// Keys are hashed and compared bytewise (HashKey / KeyEquals), so Key should be a
// plain-data type whose padding, if any, is zero-initialized.
// Removal leaves REMOVED tombstones; Insert() purges them automatically when they
// crowd the table, and Maintain() can be called to purge them earlier.
template <class Key, class Value>
class DenseHashMap {
public:
	// Creates an empty map. initialCapacity is rounded up to a power of two (minimum 4):
	// bucket indices are computed as `hash & (capacity - 1)`, which only visits every
	// bucket when the capacity is a power of two.
	DenseHashMap(int initialCapacity) : capacity_(NormalizeCapacity(initialCapacity)) {
		map.resize(capacity_);
		state.resize(capacity_);
	}

	// Looks up `key`. On success copies the stored value into *value and returns true;
	// returns false (leaving *value untouched) if the key is absent.
	bool Get(const Key &key, Value *value) const {
		const uint32_t mask = capacity_ - 1;
		const uint32_t pos = HashKey(key) & mask;
		uint32_t p = pos;
		while (true) {
			if (state[p] == BucketState::TAKEN && KeyEquals(key, map[p].key)) {
				*value = map[p].value;
				return true;
			} else if (state[p] == BucketState::FREE) {
				// A probe chain ends at the first FREE bucket; REMOVED buckets are skipped.
				return false;
			}
			p = (p + 1) & mask;
			if (p == pos) {
				// Insert() always leaves a FREE bucket, so wrapping around means corruption.
				// Bail out rather than spinning forever if the assert handler returns.
				_assert_msg_(false, "DenseHashMap: Hit full on Get()");
				return false;
			}
		}
	}

	// Returns the value stored for `key`, or (Value)nullptr if absent.
	// Only usable when Value is constructible from nullptr (e.g. pointer types).
	Value GetOrNull(const Key &key) const {
		Value value;
		if (Get(key, &value)) {
			return value;
		} else {
			return (Value)nullptr;
		}
	}

	// True if `key` is present.
	bool ContainsKey(const Key &key) const {
		Value value;
		return Get(key, &value);
	}

	// Inserts key -> value. Inserting a key that is already present is a caller bug:
	// it asserts and returns false without modifying the map.
	bool Insert(const Key &key, Value value) {
		PrepareForInsert();
		const uint32_t mask = capacity_ - 1;
		const uint32_t pos = HashKey(key) & mask;
		uint32_t p = pos;
		// The first tombstone on the probe chain is the preferred slot, but we must keep
		// scanning to the terminating FREE bucket: the key may live past the tombstone.
		bool haveTombstone = false;
		uint32_t tombstone = 0;
		while (state[p] != BucketState::FREE) {
			if (state[p] == BucketState::TAKEN) {
				if (KeyEquals(key, map[p].key)) {
					_assert_msg_(false, "DenseHashMap: Duplicate key of size %d inserted", (int)sizeof(Key));
					return false;
				}
			} else if (!haveTombstone) {
				haveTombstone = true;
				tombstone = p;
			}
			p = (p + 1) & mask;
			if (p == pos) {
				// Scanned every bucket without finding the key or a FREE bucket.
				// PrepareForInsert() should make this impossible.
				if (!haveTombstone) {
					_assert_msg_(false, "DenseHashMap: Hit full on Insert()");
					return false;
				}
				break;
			}
		}
		if (haveTombstone) {
			p = tombstone;
			removedCount_--;
		}
		state[p] = BucketState::TAKEN;
		map[p].key = key;
		map[p].value = value;
		count_++;
		return true;
	}

	// Removes `key` if present, leaving a tombstone. Returns whether anything was removed.
	bool Remove(const Key &key) {
		const uint32_t mask = capacity_ - 1;
		const uint32_t pos = HashKey(key) & mask;
		uint32_t p = pos;
		while (state[p] != BucketState::FREE) {
			if (state[p] == BucketState::TAKEN && KeyEquals(key, map[p].key)) {
				// A tombstone (not FREE) keeps probe chains through this bucket intact.
				state[p] = BucketState::REMOVED;
				removedCount_++;
				count_--;
				return true;
			}
			p = (p + 1) & mask;
			if (p == pos) {
				_assert_msg_(false, "DenseHashMap: Hit full on Remove()");
				return false;
			}
		}
		return false;
	}

	// Number of live entries.
	size_t size() const {
		return count_;
	}

	// Calls func(key, value) for every live entry, in bucket order (not insertion order).
	template<class T>
	inline void Iterate(T func) const {
		for (size_t i = 0; i < map.size(); i++) {
			if (state[i] == BucketState::TAKEN) {
				func(map[i].key, map[i].value);
			}
		}
	}

	// Like Iterate(), but func may modify the stored values. It must not mutate keys,
	// since that would desynchronize them from their buckets.
	template<class T>
	inline void IterateMut(T func) {
		for (size_t i = 0; i < map.size(); i++) {
			if (state[i] == BucketState::TAKEN) {
				func(map[i].key, map[i].value);
			}
		}
	}

	// Drops all entries; capacity is kept. Stored keys/values are not destroyed,
	// only marked FREE.
	void Clear() {
		memset(state.data(), (int)BucketState::FREE, state.size());
		count_ = 0;
		removedCount_ = 0;
	}

	// Rehashes in place, purging all tombstones.
	void Rebuild() {
		Grow(1);
	}

	// Purges tombstones once they occupy at least a quarter of the table.
	void Maintain() {
		if (removedCount_ >= capacity_ / 4) {
			Rebuild();
		}
	}

private:
	// Rounds a requested capacity up to a power of two, at least 4. Below 4 the
	// half-load growth policy would allow the table to become completely full.
	static int NormalizeCapacity(int requested) {
		int cap = 4;
		while (cap < requested && cap < (1 << 30)) {
			cap *= 2;
		}
		return cap;
	}

	// Ensures that, after one more insertion, at least one FREE bucket remains, so every
	// probe for an absent key terminates. Grows when live entries exceed half the table.
	// When only tombstones are crowding it, rebuilds at the same size instead.
	void PrepareForInsert() {
		if (count_ > capacity_ / 2) {
			Grow(2);
		} else if (count_ + removedCount_ >= capacity_ - capacity_ / 4) {
			Grow(1);
		}
	}

	// Rehashes every live entry into a fresh table of capacity_ * factor buckets,
	// dropping all tombstones. Re-insertion cannot re-trigger growth: the fresh table
	// has no tombstones and at most the old live count.
	void Grow(int factor) {
		std::vector<Pair> oldMap = std::move(map);
		std::vector<BucketState> oldState = std::move(state);
		const int oldCount = count_;
		capacity_ *= factor;
		map.clear();
		map.resize(capacity_);
		state.clear();
		state.resize(capacity_);  // value-initialized == BucketState::FREE
		count_ = 0;
		removedCount_ = 0;
		for (size_t i = 0; i < oldMap.size(); i++) {
			if (oldState[i] == BucketState::TAKEN) {
				Insert(oldMap[i].key, oldMap[i].value);
			}
		}
		_assert_msg_(oldCount == count_, "DenseHashMap: count should not change in Grow()");
	}

	struct Pair {
		Key key;
		Value value;
	};
	std::vector<Pair> map;            // bucket storage, parallel to `state`
	std::vector<BucketState> state;   // occupancy of each bucket
	int capacity_;                    // bucket count; always a power of two
	int count_ = 0;                   // TAKEN buckets
	int removedCount_ = 0;            // REMOVED (tombstone) buckets
};
template <class Value>
class PrehashMap {
public:
PrehashMap(int initialCapacity) : capacity_(initialCapacity) {
map.resize(initialCapacity);
state.resize(initialCapacity);
}
bool Get(uint32_t hash, Value *value) {
uint32_t mask = capacity_ - 1;
uint32_t pos = hash & mask;
uint32_t p = pos;
while (true) {
if (state[p] == BucketState::TAKEN && hash == map[p].hash) {
*value = map[p].value;
return true;
} else if (state[p] == BucketState::FREE) {
return false;
}
p = (p + 1) & mask;
if (p == pos) {
_assert_msg_(false, "PrehashMap: Hit full on Get()");
}
}
return false;
}
bool Insert(uint32_t hash, Value value) {
if (count_ > capacity_ / 2) {
Grow(2);
}
uint32_t mask = capacity_ - 1;
uint32_t pos = hash & mask;
uint32_t p = pos;
while (state[p] != BucketState::FREE) {
if (state[p] == BucketState::TAKEN) {
if (hash == map[p].hash)
return false;
} else {
break;
}
p = (p + 1) & mask;
if (p == pos) {
_assert_msg_(false, "PrehashMap: Hit full on Insert()");
}
}
if (state[p] == BucketState::REMOVED) {
removedCount_--;
}
state[p] = BucketState::TAKEN;
map[p].hash = hash;
map[p].value = value;
count_++;
return true;
}
bool Remove(uint32_t hash) {
uint32_t mask = capacity_ - 1;
uint32_t pos = hash & mask;
uint32_t p = pos;
while (state[p] != BucketState::FREE) {
if (state[p] == BucketState::TAKEN && hash == map[p].hash) {
state[p] = BucketState::REMOVED;
removedCount_++;
count_--;
return true;
}
p = (p + 1) & mask;
if (p == pos) {
_assert_msg_(false, "PrehashMap: Hit full on Remove()");
}
}
return false;
}
size_t size() {
return count_;
}
template<class T>
void Iterate(T func) const {
for (size_t i = 0; i < map.size(); i++) {
if (state[i] == BucketState::TAKEN) {
func(map[i].hash, map[i].value);
}
}
}
void Clear() {
memset(state.data(), (int)BucketState::FREE, state.size());
count_ = 0;
removedCount_ = 0;
}
void Rebuild() {
Grow(1);
}
void Maintain() {
if (removedCount_ >= capacity_ / 4) {
Rebuild();
}
}
private:
void Grow(int factor) {
std::vector<Pair> old = std::move(map);
std::vector<BucketState> oldState = std::move(state);
map.clear();
state.clear();
int oldCount = count_;
int oldCapacity = capacity_;
capacity_ *= factor;
map.resize(capacity_);
state.resize(capacity_);
count_ = 0;
removedCount_ = 0;
for (size_t i = 0; i < old.size(); i++) {
if (oldState[i] == BucketState::TAKEN) {
Insert(old[i].hash, old[i].value);
}
}
INFO_LOG(Log::G3D, "Grew hashmap capacity from %d to %d", oldCapacity, capacity_);
_assert_msg_(oldCount == count_, "PrehashMap: count should not change in Grow()");
}
struct Pair {
uint32_t hash;
Value value;
};
std::vector<Pair> map;
std::vector<BucketState> state;
int capacity_;
int count_ = 0;
int removedCount_ = 0;
};