Skip to content

Commit 1367741

Browse files
committed
cache segment activation states. implement hashCode,equals on Cell to ensure proper reproducibility
1 parent 0218b4d commit 1367741

File tree

4 files changed

+102
-32
lines changed

4 files changed

+102
-32
lines changed

src/HTM/java/htm/Cell.java

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ public class Cell extends AbstractCell {
1313

1414
private final Column _column;
1515
private final int _index;
16+
private final int _id;
1617
private boolean _isActive;
1718
private boolean _wasActive;
1819
private boolean _isPredicting;
@@ -40,6 +41,24 @@ public class Cell extends AbstractCell {
4041
_wasPredicted = false;
4142
_isLearning = false;
4243
_wasLearning = false;
44+
int cpc = _column.getRegion().getCellsPerCol();
45+
_id = (_column.cx()*cpc + _index) +
46+
(_column.cy()*cpc*_column.getRegion().getWidth());
47+
}
48+
49+
@Override
50+
public int hashCode() {
51+
return _id;
52+
}
53+
54+
@Override
55+
public boolean equals(Object obj) {
56+
if(obj instanceof Cell) {
57+
Cell cell = (Cell)obj;
58+
if(cell.getRegion()==getRegion())
59+
return cell._id==_id;
60+
}
61+
return false;
4362
}
4463

4564
public Region getRegion() {
@@ -76,6 +95,10 @@ public int gridIndex() {
7695
return _column.gridIndex();
7796
}
7897

98+
public int getId() {
99+
return _id;
100+
}
101+
79102
/** Return the Cell's index position within its Column. */
80103
public int getIndex() { return _index; }
81104
@Override
@@ -159,6 +182,8 @@ void nextTimeStep() {
159182
_isActive = false;
160183
_isPredicting = false;
161184
_isLearning = false;
185+
for(Segment seg : _segments)
186+
seg.nextTimeStep();
162187
}
163188

164189
/**
@@ -255,10 +280,8 @@ SegmentUpdateInfo updateSegmentActiveSynapses(boolean previous,
255280
Segment segment, boolean newSynapses) {
256281
Set<Synapse> activeSyns = new HashSet<Synapse>();
257282
if(segment!=null) {
258-
if(previous)
259-
segment.getPrevActiveSynapses(activeSyns);
260-
else
261-
segment.getActiveSynapses(activeSyns);
283+
activeSyns = previous ? segment.getPrevActiveSynapses() :
284+
segment.getActiveSynapses();
262285
}
263286

264287
SegmentUpdateInfo segmentUpdate =

src/HTM/java/htm/Column.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ public CellAndSegment(Cell cell, Segment segment) {
6262
*/
6363
Column(Region region, int srcPosX, int srcPosY, int posX, int posY) {
6464
_region = region;
65+
_ix = srcPosX; //'input' row and col
66+
_iy = srcPosY;
67+
_cx = posX; //'column grid' row and col
68+
_cy = posY;
69+
6570
int numCells = region.getCellsPerCol();
6671
_cells = new Cell[numCells];
6772
for(int i=0; i<numCells; ++i)
@@ -85,10 +90,6 @@ public CellAndSegment(Cell cell, Segment segment) {
8590
_overlapDutyCycle = 1.0f;
8691

8792
_overlap = 0; //the last computed input overlap for the Column.
88-
_ix = srcPosX; //'input' row and col
89-
_iy = srcPosY;
90-
_cx = posX; //'column grid' row and col
91-
_cy = posY;
9293
}
9394

9495
public int ix() { return _ix; }

src/HTM/java/htm/Region.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ public Region(int inputSizeX, int inputSizeY, int localityRadius,
149149
_columns = new Column[_width*_height];
150150
for(int cx=0; cx<_width; ++cx) {
151151
for(int cy=0; cy<_height; ++cy) {
152-
_columns[(cy*_height)+cx] = new Column(this, cx, cy, cx, cy);
152+
_columns[(cy*_width)+cx] = new Column(this, cx, cy, cx, cy);
153153
}
154154
}
155155

@@ -1020,6 +1020,11 @@ private void performTemporalPooling() {
10201020
for(Column col : _columns) {
10211021
for(int c=0; c<col.numCells(); ++c) {
10221022
Cell cell = col.getCell(c);
1023+
1024+
//process all segments on the cell to cache the activity for later
1025+
for(int s=0; s<cell.numSegments(); ++s)
1026+
cell.getSegment(s).processSegment();
1027+
10231028
for(int s=0; s<cell.numSegments(); ++s) {
10241029
Segment seg = cell.getSegment(s);
10251030
if(seg.isActive()) {
@@ -1218,6 +1223,11 @@ protected void compute() {
12181223
Column col = _region._columns[i];
12191224
for(int c=0; c<col.numCells(); ++c) {
12201225
Cell cell = col.getCell(c);
1226+
1227+
//process all segments on the cell to cache the activity for later
1228+
for(int s=0; s<cell.numSegments(); ++s)
1229+
cell.getSegment(s).processSegment();
1230+
12211231
for(int s=0; s<cell.numSegments(); ++s) {
12221232
Segment seg = cell.getSegment(s);
12231233
if(seg.isActive()) {

src/HTM/java/htm/Segment.java

Lines changed: 59 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package htm;
22

33
import java.util.ArrayList;
4+
import java.util.Collections;
5+
import java.util.HashSet;
46
import java.util.List;
57
import java.util.Set;
68

@@ -23,13 +25,54 @@ public class Segment {
2325
private int _predictionSteps;
2426
private final float _segActiveThreshold;
2527

28+
private boolean _isActive;
29+
private boolean _wasActive;
30+
private Set<Synapse> _activeSynapses;
31+
private Set<Synapse> _prevSynapses;
32+
2633
/**
2734
* Initialize a new Segment with the specified segment activation threshold.
2835
*/
2936
Segment(int segActiveThreshold) {
3037
_synapses = new ArrayList<Synapse>();
3138
_isSequence = false;
3239
_segActiveThreshold = segActiveThreshold;
40+
_isActive = false;
41+
_wasActive = false;
42+
_activeSynapses = new HashSet<Synapse>();
43+
_prevSynapses = new HashSet<Synapse>();
44+
}
45+
46+
/**
47+
* Advance this segment to the next time step. The current state of this
48+
* segment (active, number of synapes) will be set as the previous state and
49+
* the current state will be reset to no cell activity by default until it
50+
* can be determined.
51+
*/
52+
void nextTimeStep() {
53+
_wasActive = _isActive;
54+
_isActive = false;
55+
_prevSynapses.clear();
56+
_prevSynapses.addAll(_activeSynapses);
57+
_activeSynapses.clear();
58+
}
59+
60+
/**
61+
* Process this segment for the current time step. Processing will determine
62+
* the set of active synapses on this segment for this time step. From there
63+
* we will determine if this segment is active if enough active synapses
64+
* are present. This information is then cached for the remainder of the
65+
* Region's processing for the time step. When a new time step occurs, the
66+
* Region will call nextTimeStep() on all cells/segments to cache the
67+
* information as what was previously active.
68+
*/
69+
public void processSegment() {
70+
_activeSynapses.clear();
71+
for(Synapse syn : _synapses) {
72+
if(syn.isActive())
73+
_activeSynapses.add(syn);
74+
}
75+
_isActive = _activeSynapses.size() >= _segActiveThreshold;
3376
}
3477

3578
/**
@@ -40,7 +83,7 @@ public class Segment {
4083
* (something that will eventually happen).
4184
* @param sequence true to make the segment a sequence segment, false not.
4285
*/
43-
public void setSequence(boolean sequence) {
86+
private void setSequence(boolean sequence) {
4487
_isSequence = sequence;
4588
}
4689

@@ -140,15 +183,11 @@ public void getSynapseCells(Set<Cell> cells) {
140183
}
141184

142185
/**
143-
* Populate the set with all the currently active (firing) synapses on
144-
* this segment.
145-
* @param connectedOnly: only consider if active if a synapse is connected.
186+
* Return the set of all the currently active (connected and firing)
187+
* synapses on this segment.
146188
*/
147-
public void getActiveSynapses(Set<Synapse> syns) {
148-
for(Synapse syn : _synapses) {
149-
if(syn.isActive())
150-
syns.add(syn);
151-
}
189+
public Set<Synapse> getActiveSynapses() {
190+
return Collections.unmodifiableSet(_activeSynapses);
152191
}
153192

154193
/**
@@ -166,6 +205,9 @@ public int getActiveSynapseCount() {
166205
* synapses which are currently connected.
167206
*/
168207
public int getActiveSynapseCount(boolean connectedOnly) {
208+
if(connectedOnly)
209+
return _activeSynapses.size();
210+
169211
int c=0;
170212
for(Synapse syn : _synapses) {
171213
if(syn.isActive(connectedOnly))
@@ -175,15 +217,11 @@ public int getActiveSynapseCount(boolean connectedOnly) {
175217
}
176218

177219
/**
178-
* Populate the set with all the previously active (firing) synapses on
220+
* Return the set of all the previously active (firing) synapses on
179221
* this segment.
180-
* @param connectedOnly: only consider if active if a synapse is connected.
181222
*/
182-
public void getPrevActiveSynapses(Set<Synapse> syns) {
183-
for(Synapse syn : _synapses) {
184-
if(syn.wasActive())
185-
syns.add(syn);
186-
}
223+
public Set<Synapse> getPrevActiveSynapses() {
224+
return Collections.unmodifiableSet(_prevSynapses);
187225
}
188226

189227
/**
@@ -201,6 +239,9 @@ public int getPrevActiveSynapseCount() {
201239
* synapses which are currently connected.
202240
*/
203241
public int getPrevActiveSynapseCount(boolean connectedOnly) {
242+
if(connectedOnly)
243+
return _prevSynapses.size();
244+
204245
int c=0;
205246
for(Synapse syn : _synapses) {
206247
if(syn.wasActive(connectedOnly))
@@ -264,20 +305,15 @@ public void decreasePermanences(Set<Synapse> activeSynapses) {
264305
* that are active due to active states at time t is greater than activationThreshold.
265306
*/
266307
public boolean isActive() {
267-
int c=0;
268-
for(Synapse syn : _synapses) {
269-
if(syn.isActive())
270-
++c;
271-
}
272-
return c >= _segActiveThreshold;
308+
return _isActive;
273309
}
274310

275311
/**
276312
* This routine returns true if the number of connected synapses on this segment
277313
* that were active due to active states at time t-1 is greater than activationThreshold.
278314
*/
279315
public boolean wasActive() {
280-
return getPrevActiveSynapseCount() >= _segActiveThreshold;
316+
return _wasActive;
281317
}
282318

283319
/**

0 commit comments

Comments
 (0)