Skip to content

Commit 3ca282d

Browse files
authored
add differential pointer return c test (#29)
* add differential pointer return c test * change all tests to do approx fp comparisons, will be needed for more complex tests
1 parent cd8a395 commit 3ca282d

13 files changed

+126
-105
lines changed
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#include <stdio.h>
2+
#include <stdlib.h>
3+
#include <math.h>
4+
#include <assert.h>
5+
6+
#include "test_utils.h"
7+
8+
#define __builtin_autodiff __enzyme_autodiff
9+
double __enzyme_autodiff(void*, ...);
10+
11+
double f_read(double* x) {
12+
double product = (*x) * (*x);
13+
return product;
14+
}
15+
16+
double* g_write(double* x, double product) {
17+
*x = (*x) * product;
18+
return x;
19+
}
20+
21+
double h_read(double* x) {
22+
return *x;
23+
}
24+
25+
double readwriteread_helper(double* x) {
26+
double product = f_read(x);
27+
x = g_write(x, product);
28+
double ret = h_read(x);
29+
return ret;
30+
}
31+
32+
void readwriteread(double*__restrict x, double*__restrict ret) {
33+
*ret = readwriteread_helper(x);
34+
}
35+
36+
int main(int argc, char** argv) {
37+
double ret = 0;
38+
double dret = 1.0;
39+
double* x = (double*) malloc(sizeof(double));
40+
double* dx = (double*) malloc(sizeof(double));
41+
*x = 2.0;
42+
*dx = 0.0;
43+
44+
__builtin_autodiff(readwriteread, x, dx, &ret, &dret);
45+
46+
47+
printf("dx is %f ret is %f\n", *dx, ret);
48+
assert(approx_fp_equality_float(*dx, 3*2.0*2.0, 1e-10));
49+
return 0;
50+
}

enzyme/functional_tests_c/insertsort_sum.c

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
#include <math.h>
44
#include <assert.h>
55

6+
#include "test_utils.h"
7+
68
#define __builtin_autodiff __enzyme_autodiff
79

810
double __enzyme_autodiff(void*, ...);
@@ -67,9 +69,9 @@ int main(int argc, char** argv) {
6769
for (int i = 0; i < N; i++) {
6870
printf("Diffe for index %d is %f\n", i, d_array[i]);
6971
if (i%2 == 0) {
70-
assert(d_array[i] == 0.0);
72+
assert(approx_fp_equality_float(d_array[i], 0.0, 1e-10));
7173
} else {
72-
assert(d_array[i] == 1.0);
74+
assert(approx_fp_equality_float(d_array[i],1.0,1e-10));
7375
}
7476
}
7577
return 0;

enzyme/functional_tests_c/insertsort_sum_alt.c

Lines changed: 4 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,10 @@
33
#include <math.h>
44
#include <assert.h>
55

6+
#include "test_utils.h"
7+
68
#define __builtin_autodiff __enzyme_autodiff
79
double __enzyme_autodiff(void*, ...);
8-
//float man_max(float* a, float* b) {
9-
// if (*a > *b) {
10-
// return *a;
11-
// } else {
12-
// return *b;
13-
// }
14-
//}
15-
1610

1711
// size of array
1812
float* unsorted_array_init(int N) {
@@ -37,36 +31,19 @@ void insertion_sort_inner(float* array, int i) {
3731
// sums the first half of a sorted array.
3832
void insertsort_sum (float*__restrict array, int N, float*__restrict ret) {
3933
float sum = 0;
40-
//qsort(array, N, sizeof(float), cmp);
4134

4235
for (int i = 1; i < N; i++) {
4336
insertion_sort_inner(array, i);
4437
}
4538

46-
4739
for (int i = 0; i < N/2; i++) {
4840
//printf("Val: %f\n", array[i]);
4941
sum += array[i];
5042
}
5143
*ret = sum;
5244
}
5345

54-
55-
56-
5746
int main(int argc, char** argv) {
58-
59-
60-
61-
float a = 2.0;
62-
float b = 3.0;
63-
64-
65-
66-
float da = 0;//(float*) malloc(sizeof(float));
67-
float db = 0;//(float*) malloc(sizeof(float));
68-
69-
7047
float ret = 0;
7148
float dret = 1.0;
7249

@@ -85,15 +62,11 @@ int main(int argc, char** argv) {
8562
for (int i = 0; i < N; i++) {
8663
printf("Diffe for index %d is %f\n", i, d_array[i]);
8764
if (i%2 == 0) {
88-
assert(d_array[i] == 0.0);
65+
assert(approx_fp_equality_float(d_array[i], 0.0, 1e-10));
8966
} else {
90-
assert(d_array[i] == 1.0);
67+
assert(approx_fp_equality_float(d_array[i], 1.0, 1e-10));
9168
}
9269
}
9370

94-
//assert(da == 100*1.0f);
95-
//assert(db == 100*1.0f);
96-
97-
//printf("hello! %f, res2 %f, da: %f, db: %f\n", ret, ret, da,db);
9871
return 0;
9972
}

enzyme/functional_tests_c/loops.c

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
#include <math.h>
33
#include <assert.h>
44

5-
#define __builtin_autodiff __enzyme_autodiff
5+
#include "test_utils.h"
66

7+
#define __builtin_autodiff __enzyme_autodiff
78

89
double __enzyme_autodiff(void*, ...);
910

@@ -17,14 +18,6 @@ double __enzyme_autodiff(void*, ...);
1718
void compute_loops(float* a, float* b, float* ret) {
1819
double sum0 = 0.0;
1920
for (int i = 0; i < 100; i++) {
20-
//double sum1 = 0.0;
21-
//for (int j = 0; j < 100; j++) {
22-
// //double sum2 = 0.0;
23-
// //for (int k = 0; k < 100; k++) {
24-
// // sum2 += *a+*b;
25-
// //}
26-
// sum1 += *a+*b;
27-
//}
2821
sum0 += *a + *b;
2922
}
3023
*ret = sum0;
@@ -48,13 +41,11 @@ int main(int argc, char** argv) {
4841
float ret = 0;
4942
float dret = 1.0;
5043

51-
//compute_loops(&a, &b, &ret);
52-
5344
__builtin_autodiff(compute_loops, &a, &da, &b, &db, &ret, &dret);
5445

5546

56-
assert(da == 100*1.0f);
57-
assert(db == 100*1.0f);
47+
assert(approx_fp_equality_float(da, 100*1.0f, 1e-10));
48+
assert(approx_fp_equality_float(db, 100*1.0f, 1e-10));
5849

5950
printf("hello! %f, res2 %f, da: %f, db: %f\n", ret, ret, da,db);
6051
return 0;

enzyme/functional_tests_c/loopsdouble.c

Lines changed: 8 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,38 @@
11
#include <stdio.h>
22
#include <math.h>
33
#include <assert.h>
4+
5+
#include "test_utils.h"
6+
47
#define __builtin_autodiff __enzyme_autodiff
58
double __enzyme_autodiff(void*, ...);
6-
//float man_max(float* a, float* b) {
7-
// if (*a > *b) {
8-
// return *a;
9-
// } else {
10-
// return *b;
11-
// }
12-
//}
9+
1310
void compute_loops(float* a, float* b, float* ret) {
1411
double sum0 = 0.0;
1512
for (int i = 0; i < 100; i++) {
1613
double sum1 = 0.0;
1714
for (int j = 0; j < 100; j++) {
18-
//double sum2 = 0.0;
19-
//for (int k = 0; k < 100; k++) {
20-
// sum2 += *a+*b;
21-
//}
2215
sum1 += *a+*b;
2316
}
2417
sum0 += sum1;
2518
}
2619
*ret = sum0;
2720
}
2821

29-
30-
3122
int main(int argc, char** argv) {
32-
33-
34-
3523
float a = 2.0;
3624
float b = 3.0;
3725

38-
39-
40-
float da = 0;//(float*) malloc(sizeof(float));
41-
float db = 0;//(float*) malloc(sizeof(float));
42-
26+
float da = 0;
27+
float db = 0;
4328

4429
float ret = 0;
4530
float dret = 1.0;
4631

47-
//compute_loops(&a, &b, &ret);
48-
4932
__builtin_autodiff(compute_loops, &a, &da, &b, &db, &ret, &dret);
5033

51-
52-
assert(da == 100*100*1.0f);
53-
assert(db == 100*100*1.0f);
34+
assert(approx_fp_equality_float(da, 100*100*1.0f, 1e-10));
35+
assert(approx_fp_equality_float(db, 100*100*1.0f, 1e-10));
5436

5537
printf("hello! %f, res2 %f, da: %f, db: %f\n", ret, ret, da,db);
5638
return 0;

enzyme/functional_tests_c/loopstriple.c

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#include <math.h>
33
#include <assert.h>
44

5+
#include "test_utils.h"
6+
57
#define __builtin_autodiff __enzyme_autodiff
68
double __enzyme_autodiff(void*, ...);
79
void compute_loops(float* a, float* b, float* ret) {
@@ -20,31 +22,21 @@ void compute_loops(float* a, float* b, float* ret) {
2022
*ret = sum0;
2123
}
2224

23-
24-
2525
int main(int argc, char** argv) {
2626

27-
28-
2927
float a = 2.0;
3028
float b = 3.0;
3129

32-
33-
34-
float da = 0;//(float*) malloc(sizeof(float));
35-
float db = 0;//(float*) malloc(sizeof(float));
36-
30+
float da = 0;
31+
float db = 0;
3732

3833
float ret = 0;
3934
float dret = 1.0;
4035

41-
//compute_loops(&a, &b, &ret);
42-
4336
__builtin_autodiff(compute_loops, &a, &da, &b, &db, &ret, &dret);
4437

45-
46-
assert(da == 100*100*100*1.0f);
47-
assert(db == 100*100*100*1.0f);
38+
assert(approx_fp_equality_float(da, 100*100*100*1.0f, 1e-10));
39+
assert(approx_fp_equality_float(db, 100*100*100*1.0f, 1e-10));
4840

4941
printf("hello! %f, res2 %f, da: %f, db: %f\n", ret, ret, da,db);
5042
return 0;

enzyme/functional_tests_c/readwriteread.c

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
#include <stdlib.h>
33
#include <math.h>
44
#include <assert.h>
5+
6+
#include "test_utils.h"
7+
58
#define __builtin_autodiff __enzyme_autodiff
69
double __enzyme_autodiff(void*, ...);
710

@@ -38,9 +41,8 @@ int main(int argc, char** argv) {
3841
*dx = 0.0;
3942

4043
__builtin_autodiff(readwriteread, x, dx, &ret, &dret);
41-
4244

4345
printf("dx is %f ret is %f\n", *dx, ret);
44-
assert(*dx == 3*2.0*2.0);
46+
assert(approx_fp_equality_double(*dx, 3*2.0*2.0, 1e-10));
4547
return 0;
4648
}

enzyme/functional_tests_c/recurse.c

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
#include <stdio.h>
22
#include <math.h>
33
#include <assert.h>
4+
5+
#include "test_utils.h"
6+
47
#define __builtin_autodiff __enzyme_autodiff
58
double __enzyme_autodiff(void*, ...);
6-
int counter = 0;
9+
710
double recurse_max_helper(float* a, float* b, int N) {
811
if (N <= 0) {
912
return *a + *b;
@@ -14,36 +17,25 @@ void recurse_max(float* a, float* b, float* ret, int N) {
1417
*ret = recurse_max_helper(a,b,N);
1518
}
1619

17-
18-
1920
int main(int argc, char** argv) {
20-
21-
22-
2321
float a = 2.0;
2422
float b = 3.0;
2523

26-
27-
28-
float da = 0;//(float*) malloc(sizeof(float));
29-
float db = 0;//(float*) malloc(sizeof(float));
30-
24+
float da = 0;
25+
float db = 0;
3126

3227
float ret = 0;
3328
float dret = 2.0;
3429

35-
recurse_max(&a, &b, &ret, 20);
30+
//recurse_max(&a, &b, &ret, 20);
3631

3732
int N = 20;
3833
int dN = 0;
3934

4035
__builtin_autodiff(recurse_max, &a, &da, &b, &db, &ret, &dret, 20);
4136

42-
43-
assert(da == 17711.0*2);
44-
assert(db == 17711.0*2);
45-
46-
37+
assert(approx_fp_equality_float(da, 17711.0*2, 1e-10));
38+
assert(approx_fp_equality_float(db, 17711.0*2, 1e-10));
4739

4840
printf("hello! %f, res2 %f, da: %f, db: %f\n", ret, ret, da,db);
4941
return 0;
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#include <stdlib.h>
2+
#include <stdbool.h>
3+
#include <math.h>
4+
5+
static bool approx_fp_equality_float(float f1, float f2, double threshold) {
6+
if (fabs(f1-f2) > threshold) return false;
7+
return true;
8+
}
9+
10+
static bool approx_fp_equality_double(double f1, double f2, double threshold) {
11+
if (fabs(f1-f2) > threshold) return false;
12+
return true;
13+
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
; RUN: cd %desired_wd
2+
; RUN: make clean-differential_pointer_return-enzyme0 ENZYME_PLUGIN=%loadEnzyme
3+
; RUN: make build/differential_pointer_return-enzyme0 ENZYME_PLUGIN=%loadEnzyme CLANG_BIN_PATH=%clangBinPath
4+
; RUN: build/differential_pointer_return-enzyme0
5+
; RUN: make clean-differential_pointer_return-enzyme0 ENZYME_PLUGIN=%loadEnzyme
6+

0 commit comments

Comments
 (0)