Commit 5f9762b

Fix an issue with update step
1 parent 8cfc858 commit 5f9762b

5 files changed: 49 additions, 32 deletions

SixJoint/Learning.cpp (47 additions, 27 deletions)

@@ -11,7 +11,7 @@
 #include "Arm.h"


-#define NUM_PERTURBATIONS 2 * NUM_POLICY_FEATURES
+#define NUM_PERTURBATIONS NUM_POLICY_FEATURES

 extern ArmState startState;
 extern ArmState currentState;
@@ -23,25 +23,31 @@ extern ArmState targetState;
 #define P_C(parameters, index) parameters[4 * index + 2u]
 #define P_MV(parameters, index) parameters[4 * index + 3u]

-float theta[NUM_POLICY_FEATURES] = {.5};
+float theta[NUM_POLICY_FEATURES] = {.5, .5, .5, .5,
+                                    .5, .5, .5, .5,
+                                    .5, .5, .5, .5,
+                                    .5, .5, .5, .5,
+                                    .5, .5, .5, .5,
+                                    .5, .5, .5, .5};
 float actingTheta[NUM_POLICY_FEATURES] = {0};
 float perturbations[NUM_PERTURBATIONS][NUM_POLICY_FEATURES] = {0};

 extern uint8_t jointRangeMin[];
 extern uint8_t jointRnageMax[];

-float alpha = 0.001;
+float alpha = DEFAULT_ALPHA;
 float rl_gamma = 0.9999999;
-float I;
+
+

 void logPolicyParameters() {
   logVector(theta, NUM_POLICY_FEATURES);

 }

 float evaluatePolicy() {
-  ArmAction deltaToGoal;
   moveSmoothlyTo(startState);
+  ArmAction deltaToGoal;
   actionBetweenStates(currentState, targetState, deltaToGoal);
   float equations[NUM_JOINTS][4];
   uint32_t maxIterations = 0;
@@ -66,14 +72,14 @@ float evaluatePolicy() {
     equations[j][2] = gamma;

     const float velocityFactor = exp(P_MV(actingTheta, j)) + 1;
-    const float iterations = 10.0 * percentMax * velocityFactor;
+    const float iterations = 5.0 * percentMax * velocityFactor;
     equations[j][3] = ceil(iterations);
     maxIterations = max(maxIterations, (uint32_t)(iterations));

     //D_LOG_V("EQ", equations[j], 4);

   }
-  maxIterations = min(maxIterations, 2000);
+  maxIterations = min(maxIterations, 2000u);
   //D_LOG("iterations", maxIterations);
   resetPowerMeasurement();

@@ -106,11 +112,11 @@ void iterate() {
     add(perturbations[i], theta, actingTheta, NUM_POLICY_FEATURES);
     const float evaluation = evaluatePolicy();
     evaluations[i] = evaluation;
-    D_LOG("eval", evaluation);
   }
-  float numUp[NUM_POLICY_FEATURES] = {0};
-  float numDown[NUM_POLICY_FEATURES] = {0};
-  float numNone[NUM_POLICY_FEATURES] = {0};
+  D_LOG_V("evals", evaluations, NUM_PERTURBATIONS);
+  uint8_t numUp[NUM_POLICY_FEATURES] = {0};
+  uint8_t numDown[NUM_POLICY_FEATURES] = {0};
+  uint8_t numNone[NUM_POLICY_FEATURES] = {0};
   float averageUp[NUM_POLICY_FEATURES] = {0};
   float averageDown[NUM_POLICY_FEATURES] = {0};
   float averageNone[NUM_POLICY_FEATURES] = {0};
@@ -121,28 +127,34 @@ void iterate() {
         numUp[j] += 1;
         averageUp[j] += evaluations[i];
       } else if (direction < 0.0) {
-      numDown[j] += 1;
-      averageDown[j] += evaluations[i];
+        numDown[j] += 1;
+        averageDown[j] += evaluations[i];
       } else {
-      numNone[j] += 1;
-      averageNone[j] += evaluations[i];
+        numNone[j] += 1;
+        averageNone[j] += evaluations[i];
       }
     }
   }

-  for (uint8_t i = 0; i < NUM_PERTURBATIONS; i++) {
-    for (uint8_t j = 0; j < NUM_POLICY_FEATURES; j++) {
-      if (numUp[j] > 0) {
-        averageUp[j] /= numUp[j];
-      }
-      if (numDown[j] > 0) {
-        averageDown[j] /= numDown[j];
-      }
-      if (numNone[j] > 0) {
-        averageNone[j] /= numNone[j];
-      }
+  for (uint8_t j = 0; j < NUM_POLICY_FEATURES; j++) {
+    if (numUp[j] > 0) {
+      averageUp[j] /= (float)numUp[j];
+    }
+    if (numDown[j] > 0) {
+      averageDown[j] /= (float)numDown[j];
+    }
+    if (numNone[j] > 0) {
+      averageNone[j] /= (float)numNone[j];
     }
   }
+
+  D_LOG_V("nup", numUp, NUM_POLICY_FEATURES);
+  D_LOG_V("ndown", numDown, NUM_POLICY_FEATURES);
+  D_LOG_V("nnone", numNone, NUM_POLICY_FEATURES);
+
+  D_LOG_V("aup", averageUp, NUM_POLICY_FEATURES);
+  D_LOG_V("adown", averageDown, NUM_POLICY_FEATURES);
+  D_LOG_V("anone", averageNone, NUM_POLICY_FEATURES);

   float delta[NUM_POLICY_FEATURES] = {0};
   for (uint8_t j = 0; j < NUM_POLICY_FEATURES; j++) {
@@ -152,10 +164,18 @@ void iterate() {
       delta[j] = averageUp[j] - averageDown[j];
     }
   }
+  D_LOG_V("delta", delta, NUM_POLICY_FEATURES);
   norm(delta, NUM_POLICY_FEATURES);
-  multiply(0.10, delta, NUM_POLICY_FEATURES);
+  D_LOG_V("norm", delta, NUM_POLICY_FEATURES);
+  multiply(alpha, delta, NUM_POLICY_FEATURES);
+  D_LOG_V("step", delta, NUM_POLICY_FEATURES);
+
+  D_LOG_V("theta", theta, NUM_POLICY_FEATURES);
   add(theta, delta, NUM_POLICY_FEATURES);
+  D_LOG_V("new theta", theta, NUM_POLICY_FEATURES);
+
   copy(theta, actingTheta, NUM_POLICY_FEATURES);
+  D_LOG_V("acting theta", actingTheta, NUM_POLICY_FEATURES);

 }
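
The iterate() changes above are the "update step" named in the commit title: each rollout evaluates theta plus one perturbation vector, the scores are grouped per feature by the sign of that feature's perturbation, and the averaged up/down difference is normalized and scaled by alpha (replacing the hard-coded 0.10) before being added to theta. Below is a minimal, self-contained C++ sketch of that update, not the repository's code: evaluate(), updateStep(), and the fixed kFeatures/kPerturbations/kAlpha constants are stand-ins introduced here, the averageNone bookkeeping and the guard around the delta assignment (only partly visible in the hunk) are omitted, and norm() is assumed to rescale delta to unit length.

#include <cmath>

const int kFeatures = 24;        // stands in for NUM_POLICY_FEATURES (6 joints * 4)
const int kPerturbations = 24;   // stands in for NUM_PERTURBATIONS (now one per feature)
const float kAlpha = 0.05f;      // stands in for DEFAULT_ALPHA after this commit

float theta[kFeatures];                          // policy parameters
float perturbations[kPerturbations][kFeatures];  // assumed entries in {-0.1f, 0.0f, +0.1f}

// Stand-in for evaluatePolicy(): the real function runs the arm from the
// start state and scores the rollout; any scalar score works for the sketch.
float evaluate(const float candidate[]) {
  float score = 0.0f;
  for (int j = 0; j < kFeatures; j++) {
    score -= candidate[j] * candidate[j];        // dummy objective
  }
  return score;
}

void updateStep() {
  // 1. Score theta + perturbation for every perturbation vector.
  float evaluations[kPerturbations];
  for (int i = 0; i < kPerturbations; i++) {
    float acting[kFeatures];
    for (int j = 0; j < kFeatures; j++) {
      acting[j] = theta[j] + perturbations[i][j];
    }
    evaluations[i] = evaluate(acting);
  }

  // 2. Per feature, average the scores of rollouts that nudged it up
  //    and of rollouts that nudged it down.
  int numUp[kFeatures] = {0};
  int numDown[kFeatures] = {0};
  float averageUp[kFeatures] = {0};
  float averageDown[kFeatures] = {0};
  for (int i = 0; i < kPerturbations; i++) {
    for (int j = 0; j < kFeatures; j++) {
      const float direction = perturbations[i][j];
      if (direction > 0.0f) {
        numUp[j] += 1;
        averageUp[j] += evaluations[i];
      } else if (direction < 0.0f) {
        numDown[j] += 1;
        averageDown[j] += evaluations[i];
      }
    }
  }
  for (int j = 0; j < kFeatures; j++) {
    if (numUp[j] > 0)   averageUp[j]   /= (float)numUp[j];
    if (numDown[j] > 0) averageDown[j] /= (float)numDown[j];
  }

  // 3. Step along the normalized per-feature difference, scaled by alpha.
  float delta[kFeatures];
  float length = 0.0f;
  for (int j = 0; j < kFeatures; j++) {
    delta[j] = averageUp[j] - averageDown[j];
    length += delta[j] * delta[j];
  }
  length = std::sqrt(length);
  if (length > 0.0f) {
    for (int j = 0; j < kFeatures; j++) {
      theta[j] += kAlpha * (delta[j] / length);
    }
  }
}

One update costs NUM_PERTURBATIONS full rollouts, so dropping NUM_PERTURBATIONS from 2 * NUM_POLICY_FEATURES to NUM_POLICY_FEATURES halves the number of evaluatePolicy() calls per update.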

SixJoint/Learning.h (1 addition, 0 deletions)

@@ -5,6 +5,7 @@

 #define NUM_POLICY_FEATURES NUM_JOINTS * 4
 #define PERTURBATION_STEP 0.1
+#define DEFAULT_ALPHA 0.05

 typedef struct ArmState {
   uint8_t jointAngles[NUM_JOINTS];
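
For scale: with the six joints this project drives, NUM_POLICY_FEATURES works out to 24, which is why theta in Learning.cpp now carries 24 explicit .5 initializers and why NUM_PERTURBATIONS halves from 48 to 24. A compile-time sketch of that arithmetic follows; NUM_JOINTS = 6 is an assumption taken from the project, and the parentheses are added here for a standalone file.

// Sanity-check sketch, not part of the commit: reproduces the macro
// arithmetic with an assumed NUM_JOINTS of 6.
#define NUM_JOINTS 6
#define NUM_POLICY_FEATURES (NUM_JOINTS * 4)
#define NUM_PERTURBATIONS NUM_POLICY_FEATURES    // was 2 * NUM_POLICY_FEATURES

static_assert(NUM_POLICY_FEATURES == 24,
              "matches the 24-entry theta initializer in Learning.cpp");
static_assert(NUM_PERTURBATIONS == 24,
              "one evaluatePolicy() rollout per policy feature");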

SixJoint/SixJoint.ino (0 additions, 3 deletions)

@@ -7,13 +7,11 @@
 #include "Strings.h"
 #include "Output.h"

-#define DEFAULT_ALPHA 0.1
 #define EVALUATION_MODE 1
 #define EVALUATION_SWITCH_POINT 50
 #define EVALUATION_MAX_STEPS 200

 extern float alpha;
-extern float I;

 extern const char spaceString[];
 extern const char cumulativeRewardString[];
@@ -89,7 +87,6 @@ void loop() {
     resetArmToRandomPosition();
     currentEpisode = 0;
     alpha = DEFAULT_ALPHA;
-    I = 1.0;
     markTrialStart();
   }

SixJoint/Task.cpp (1 addition, 1 deletion)

@@ -1,7 +1,7 @@
 #include <Arduino.h>
 #include "Learning.h"

-ArmState startState = {10, 60, 50,50,50,100};
+ArmState startState = {40, 60, 50,50,50,100};
 ArmState targetState = {170,100,100,100,100,110};

 void pickNewRandomTarget() {

SixJoint/Vector.cpp (0 additions, 1 deletion)

@@ -75,7 +75,6 @@ float dot(const float first[], const float second[], const size_t length) {
     result += first[i] * second[i];

   }
-  D_LOG("m", result);
   return result;
 }
