Skip to content
Merged

Gna #2227

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
291 changes: 182 additions & 109 deletions src/sst/elements/GNA/GNA.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,75 +29,83 @@ using namespace SST::GNAComponent;
using namespace std;


GNA::GNA(ComponentId_t id, Params& params)
: Component(id),
state(IDLE),
now(0),
numFirings(0),
numDeliveries(0)
GNA::GNA (ComponentId_t id, Params & params)
: Component (id)
{
uint32_t outputLevel = params.find<uint32_t>("verbose", 0);
out.init("GNA:@p:@l: ", outputLevel, 0, Output::STDOUT);
now = 0;
neuronIndex = -1;
synapseIndex = -1;
syncSent = false;
numFirings = 0;
numDeliveries = 0;

uint32_t outputLevel = params.find<uint32_t> ("verbose", 0);
out.init ("GNA:@p:@l: ", outputLevel, 0, Output::STDOUT);

// get parameters
modelPath = params.find<string>("modelPath", "model");
steps = params.find<int> ("steps", 1000);
Neuron::dt = params.find<float> ("dt", 1);
InputsPerTic = params.find<int> ("InputsPerTic", 2);
STSDispatch = params.find<int> ("STSDispatch", 2);
STSParallelism = params.find<int> ("STSParallelism", 2);
maxOutMem = params.find<int> ("MaxOutMem", STSParallelism);
if (InputsPerTic < 1) out.fatal(CALL_INFO, -1, "InputsPerTic invalid\n");
if (STSDispatch < 1) out.fatal(CALL_INFO, -1, "STSDispatch invalid\n");
if (STSParallelism < 1) out.fatal(CALL_INFO, -1, "STSParallelism invalid\n");
if (maxOutMem < 1) out.fatal(CALL_INFO, -1, "MaxOutMem invalid\n");
modelPath = params.find<string>("modelPath", "model");
steps = params.find<int> ("steps", 1000);
Neuron::dt = params.find<float> ("dt", 1); // In seconds. Don't bother with UnitAlgebra because this is usually specified by wrapper script.
maxRequestDepth = params.find<int> ("maxRequestDepth", 2);

//set our clock
string clockFreq = params.find<string>("clock", "1GHz");
clockHandler = new Clock::Handler<GNA>(this, &GNA::clockTic);
clockTC = registerClock(clockFreq, clockHandler);
string clockFreq = params.find<string> ("clock", "1GHz");
clockTC = registerClock (clockFreq, new Clock::Handler<GNA> (this, &GNA::clockTic));

// tell the simulator not to end without us
registerAsPrimaryComponent();
primaryComponentDoNotEndSim();
registerAsPrimaryComponent ();
primaryComponentDoNotEndSim ();

// init memory
memory = loadUserSubComponent<Interfaces::StandardMem>("memory", ComponentInfo::SHARE_NONE, clockTC, new Interfaces::StandardMem::Handler<GNA>(this, &GNA::handleEvent));
if (!memory) {
params.insert("port", "mem_link");
memory = loadAnonymousSubComponent<Interfaces::StandardMem>("memHierarchy.standardInterface", "memory", 0,
ComponentInfo::SHARE_PORTS, params, clockTC, new Interfaces::StandardMem::Handler<GNA>(this, &GNA::handleEvent));
memory = loadUserSubComponent<Interfaces::StandardMem> (
"memory",
ComponentInfo::SHARE_NONE, clockTC,
new Interfaces::StandardMem::Handler<GNA> (this, &GNA::handleMemory)
);
if (!memory)
{
params.insert ("port", "mem_link");
memory = loadAnonymousSubComponent<Interfaces::StandardMem> (
"memHierarchy.standardInterface", "memory", 0,
ComponentInfo::SHARE_PORTS, params, clockTC,
new Interfaces::StandardMem::Handler<GNA>(this, &GNA::handleMemory)
);
}
if (!memory) out.fatal(CALL_INFO, -1, "Unable to load memHierarchy.standardInterface subcomponent\n");
if (!memory) out.fatal (CALL_INFO, -1, "Unable to load memHierarchy.standardInterface subcomponent\n");

link = loadUserSubComponent<Interfaces::SimpleNetwork> ("networkIF", ComponentInfo::SHARE_NONE, 1);
if (!link) out.fatal (CALL_INFO, 1, "No networkIF subcomponent\n");
link->setNotifyOnReceive (new Interfaces::SimpleNetwork::Handler<GNA> (this, &GNA::handleNetwork));
}

GNA::GNA()
GNA::GNA ()
: Component(-1)
{
// for serialization only
}

GNA::~GNA()
GNA::~GNA ()
{
for (auto n : neurons) delete n;
while (! networkRequests.empty ())
{
delete networkRequests.front ();
networkRequests.pop ();
}
}

void GNA::init(unsigned int phase)
void
GNA::init (unsigned int phase)
{
// init memory
memory->init(phase);
memory->init (phase);
link ->init (phase);

// Everything below we only do once
if (phase != 0) return;

// create STS units
for (int i = 0; i < STSParallelism; ++i) {
STSUnits.push_back(STS(this,i));
}

// Read data
// format:
// index,Vinit,Vthreshold,Vreset,leak,p -- neuron info; for input neurons, only specify index
// index,Vinit,Vthreshold,Vreset,leak,p -- neuron info; for input neurons, only specify index TODO: add Vbias
// to,weight,delay -- synapse info
// s<timing list> -- optional spike list for input neurons
// o -- optional output configuration
Expand Down Expand Up @@ -241,7 +249,7 @@ void GNA::init(unsigned int phase)
piece = strtok(0, ",");
float weight = atof(piece);
piece = strtok(0, ",");
int delay = atoi(piece) - 1; // -1 because spike delivery happens right after "now" advances.
int delay = atoi(piece);

if (n->synapseBase == 0)
{
Expand All @@ -261,95 +269,160 @@ void GNA::init(unsigned int phase)
}
}

int numNeurons = neurons.size();
int numNeurons = neurons.size ();
printf("Constructed %d neurons with %d links\n", numNeurons, countLinks);
}

void GNA::finish()
void
GNA::setup ()
{
for (auto i : Neuron::outputs) delete i.second; // flushes last row

printf("Completed %d neuron firings\n", numFirings);
printf("Completed %d spike deliveries\n", numDeliveries);
memory->setup ();
link->setup ();
}

// handle incoming memory
void GNA::handleEvent(Interfaces::StandardMem::Request * req)
void
GNA::complete (unsigned int phase)
{
map<uint64_t, STS*>::iterator i = requests.find(req->getID());
if (i == requests.end()) out.fatal(CALL_INFO, -1, "Request ID (%" PRIx64 ") not found in outstanding requests!\n", req->getID());

// handle event
STS * requestor = i->second;
requestor->returnRequest(req);
// clean up
requests.erase(i);
memory->complete (phase);
link->complete (phase);
}

void GNA::deliver(float val, int targetN, int time)
void
GNA::finish ()
{
// AFR: should really throttle this in some way
if (targetN >= neurons.size()) out.fatal(CALL_INFO, -1, "Invalid Neuron Address\n");
neurons[targetN]->deliverSpike(val, time);
numDeliveries++;
}
memory->finish ();
link ->finish ();
for (auto i : Neuron::outputs) delete i.second; // flushes last row

void GNA::readMem(Interfaces::StandardMem::Request *req, STS *requestor)
{
outgoingReqs.push(req); // queue the request to send later
requests.insert(make_pair(req->getID(), requestor)); // record who it came from
printf ("Completed %d neuron firings\n", numFirings);
printf ("Completed %d spike deliveries\n", numDeliveries);
}

void GNA::processFire()
// We simulate a von Nuemann style neuromorphic processor, working through our list of nuerons serially.
// We should execute one FLOP and one tightly-coupled-memory access per CPU cycle.
// Currently, for simplicity, we assume the full LIF model can execute in one cycle.
// Also, neuron load/save is not charged memory access time.
// Retrieving synapse records and transmitting spike packets can run in parallel with executing the LIF model,
// but the LIF model needs to stall until all spikes are sent.
bool
GNA::clockTic (Cycle_t t)
{
// assign neuron firings to lookup units (spike transfer structures)
int remainDispatches = STSDispatch;
for (auto & e : STSUnits) {
if (firedNeurons.empty()) break;
if (e.isFree()) {
e.assign(firedNeurons.front());
firedNeurons.pop_front();
remainDispatches--;
}
if (remainDispatches == 0) break;
using namespace Interfaces;
if (! networkRequests.empty ())
{
SimpleNetwork::Request * req = networkRequests.front ();
if (link->send (req, 0)) networkRequests.pop ();
}

// process neuron firings into activations
bool allSpikesDelivered = true;
for (auto & e : STSUnits) {
e.advance(now);
allSpikesDelivered &= e.isFree();
if (synapseIndex < 0) // Ready for next neuron.
{
int count = neurons.size ();
if (neuronIndex >= count) // Waiting for sync
{
if (syncSent) return false;
if (! memoryRequests.empty ()) return false; // Must finish all spikes before going to next cycle.

SyncEvent * event = new SyncEvent;
event->phase = 0;

SimpleNetwork::nid_t source = 0;
SimpleNetwork::nid_t dest = 0; // TODO: should broadcast to whole network
SimpleNetwork::Request * req = new SimpleNetwork::Request (dest, source, 1, false, false, event);
networkRequests.push (req);
syncSent = true;

return false;
}
syncSent = false; // Although this is a wasted operation most of the time, it's the simplest way to reset sync state.

neuronIndex++;
if (neuronIndex < count)
{
Neuron * n = neurons[neuronIndex];
if (n->update (now))
{
numFirings++;
if (n->synapseCount) synapseIndex = 0; // Start iterating through synapses.
}
}
}
else // Working through the synapse list for current neuron.
{
// Check if we're ready to send
if (networkRequests.size () >= maxRequestDepth) return false;
if (memoryRequests.size () >= maxRequestDepth) return false;

Neuron * n = neurons[neuronIndex];
uint64_t address = n->synapseBase + synapseIndex * sizeof (Synapse);
StandardMem::Read * req = new StandardMem::Read (address, sizeof (Synapse));
memory->send (req); // Unlike network, it seems that memory has unlimited capacity for requests.
memoryRequests.insert (address); // But we still limit the number of outstanding requests.

synapseIndex++;
if (synapseIndex >= n->synapseCount) synapseIndex = -1;
}

// do we move on?
if (allSpikesDelivered & firedNeurons.empty()) state = LIF;
return false; // keep going
}

bool GNA::clockTic(Cycle_t)
void
GNA::handleMemory (Interfaces::StandardMem::Request * req)
{
// send some outgoing mem reqs
for (int i = 0; i < maxOutMem && ! outgoingReqs.empty(); i++) {
memory->send(outgoingReqs.front());
outgoingReqs.pop();
}
SST::Interfaces::StandardMem::ReadResp * resp = dynamic_cast<SST::Interfaces::StandardMem::ReadResp *> (req);
assert (resp);
memoryRequests.erase (resp->pAddr);

Synapse * s = (Synapse *) &resp->data[0];
SpikeEvent * event = new SpikeEvent;
event->neuron = s->target;
event->weight = s->weight;
event->delay = s->delay;
delete req;

using namespace Interfaces;
SimpleNetwork::nid_t source = 0;
SimpleNetwork::nid_t dest = 0;
networkRequests.push (new SimpleNetwork::Request (dest, source, event->getSize (), false, false, event));
}

switch(state) {
case IDLE:
state = PROCESS_FIRE; // for now
break;
case PROCESS_FIRE:
processFire();
break;
case LIF:
for (auto n : neurons) if (n->update(now)) firedNeurons.push_back(n);
now++;
if (now >= steps) primaryComponentOKToEndSim();
state = PROCESS_FIRE;
numFirings += firedNeurons.size();
break;
default:
out.fatal(CALL_INFO, -1,"Invalid GNA state\n");
bool
GNA::handleNetwork (int vn)
{
// Ignore vn. It should always be 0 because that's all we registered for.

using namespace Interfaces;
while (SimpleNetwork::Request * req = link->recv (0))
{
Event * event = req->inspectPayload ();
if (SpikeEvent * spike = dynamic_cast<SpikeEvent *> (event))
{
if (spike->neuron >= neurons.size ()) out.fatal (CALL_INFO, -1, "Invalid Neuron Address\n");
neurons[spike->neuron]->deliverSpike (spike->weight, spike->delay+now);
numDeliveries++;
}
else if (SyncEvent * sync = dynamic_cast<SyncEvent *> (event))
{
now++;
neuronIndex = -1;
if (now >= steps) primaryComponentOKToEndSim ();
}
delete req;
}
return true;
}

return false; // keep going

// class SyncEvent -----------------------------------------------------------

uint32_t
SyncEvent::cls_id () const
{
return 1235;
}

string
SyncEvent::serialization_name () const
{
return "SyncEvent";
}

Loading