-
Notifications
You must be signed in to change notification settings - Fork 5.7k
add cpu random Generator #26013
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add cpu random Generator #26013
Changes from all commits
7d3e3a3
f31caa2
3810eaa
14eb642
531bd29
654aff4
3d2b02f
96909c4
3ebeec6
98ec770
ad16dad
14abbd1
e833115
8a6786d
f3cad72
4e99ac1
8f091d5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include <deque> | ||
#include <memory> | ||
#include <unordered_map> | ||
#include <unordered_set> | ||
#include <utility> | ||
|
||
#include "paddle/fluid/framework/generator.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
||
std::shared_ptr<Generator> Generator::gen_instance_ = NULL; | ||
|
||
GeneratorState* Generator::GetState() { | ||
std::lock_guard<std::mutex> lock(this->mutex); | ||
return this->state_.get(); | ||
} | ||
|
||
void Generator::SetState(GeneratorState* state_in) { | ||
std::lock_guard<std::mutex> lock(this->mutex); | ||
*this->state_ = *state_in; | ||
} | ||
|
||
uint64_t Generator::GetCurrentSeed() { | ||
std::lock_guard<std::mutex> lock(this->mutex); | ||
return this->state_->current_seed; | ||
} | ||
|
||
uint64_t Generator::Seed() { | ||
std::lock_guard<std::mutex> lock(this->mutex); | ||
uint64_t seed; | ||
std::random_device de; | ||
seed = ((((uint64_t)de()) << 32) + de()) & 0x1FFFFFFFFFFFFF; | ||
this->state_->current_seed = seed; | ||
std::seed_seq seq({seed}); | ||
this->state_->cpu_engine.seed(seq); | ||
|
||
return this->state_->current_seed; | ||
} | ||
|
||
void Generator::SetCurrentSeed(uint64_t seed) { | ||
std::lock_guard<std::mutex> lock(this->mutex); | ||
this->state_->current_seed = uint64_t(seed); | ||
std::seed_seq seq({seed}); | ||
this->state_->cpu_engine.seed(seq); | ||
} | ||
|
||
std::mt19937_64& Generator::GetCPUEngine() { | ||
std::lock_guard<std::mutex> lock(this->mutex); | ||
return this->state_->cpu_engine; | ||
} | ||
|
||
void Generator::SetCPUEngine(std::mt19937_64 engine) { | ||
std::lock_guard<std::mutex> lock(this->mutex); | ||
this->state_->cpu_engine = std::mt19937_64(engine); | ||
} | ||
|
||
uint64_t Generator::Random64() { | ||
std::lock_guard<std::mutex> lock(this->mutex); | ||
return this->state_->cpu_engine(); | ||
} | ||
|
||
} // namespace framework | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
|
||
#include <stdint.h> | ||
#include <atomic> | ||
#include <deque> | ||
#include <iostream> // temp for debug | ||
#include <memory> | ||
#include <mutex> // NOLINT | ||
#include <random> | ||
#include <typeinfo> | ||
#include <utility> | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
||
struct GeneratorState { | ||
int64_t device = -1; | ||
uint64_t current_seed = 34342423252; | ||
std::mt19937_64 cpu_engine; | ||
}; | ||
|
||
struct Generator { | ||
Generator() { | ||
GeneratorState default_gen_state_cpu; | ||
default_gen_state_cpu.device = -1; | ||
default_gen_state_cpu.current_seed = 34342423252; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. default initialize with fixed seed? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should set a default one. maybe random? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. initialization with random seed seems more reasonable |
||
std::seed_seq seq({34342423252}); | ||
default_gen_state_cpu.cpu_engine = std::mt19937_64(seq); | ||
this->state_ = std::make_shared<GeneratorState>(default_gen_state_cpu); | ||
} | ||
explicit Generator(GeneratorState state_in) | ||
: state_{std::make_shared<GeneratorState>(state_in)} {} | ||
Generator(const Generator& other) | ||
: Generator(other, std::lock_guard<std::mutex>(other.mutex)) {} | ||
|
||
// get random state | ||
GeneratorState* GetState(); | ||
// set random state | ||
void SetState(GeneratorState* state_in); | ||
// get current seed | ||
uint64_t GetCurrentSeed(); | ||
// random a seed and get | ||
uint64_t Seed(); | ||
|
||
// set seed | ||
void SetCurrentSeed(uint64_t seed); | ||
// get cpu engine | ||
std::mt19937_64& GetCPUEngine(); | ||
// set cpu engine | ||
void SetCPUEngine(std::mt19937_64 engine); | ||
|
||
uint64_t Random64(); | ||
|
||
bool is_init_py = false; | ||
|
||
// CPU Generator singleton | ||
static std::shared_ptr<Generator> GetInstance() { | ||
if (NULL == gen_instance_) { | ||
gen_instance_.reset(new paddle::framework::Generator()); | ||
} | ||
return gen_instance_; | ||
} | ||
|
||
static std::shared_ptr<Generator> GetInstanceX() { | ||
if (NULL == gen_instance_) { | ||
gen_instance_.reset(new paddle::framework::Generator()); | ||
} | ||
gen_instance_->is_init_py = true; | ||
return gen_instance_; | ||
} | ||
|
||
private: | ||
static std::shared_ptr<Generator> gen_instance_; | ||
std::shared_ptr<GeneratorState> state_; | ||
mutable std::mutex mutex; | ||
|
||
Generator(const Generator& other, const std::lock_guard<std::mutex>&) | ||
: state_(std::make_shared<GeneratorState>(*(other.state_))) {} | ||
}; | ||
|
||
} // namespace framework | ||
} // namespace paddle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why note simply merge this into
Generator
as it is one-to-one "has a" relationship?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To let State-related value be in this class. we can easily and clearly add things in this class later.