@@ -105,5 +105,82 @@ void add1(const struct cinn_buffer_t *A, const struct cinn_buffer_t *B, struct c
105105 EXPECT_EQ (utils::Trim (out), utils::Trim (target_str));
106106}
107107
// Checks the C source that CodeGenC emits for a module whose stages carry
// schedule transforms: a Split on C and a Tile on D. Both C and D also read an
// unnamed tensor (inlined0), whose expression shows up expanded in the
// expected text rather than as a loop nest of its own.
108+ TEST (CodeGenC, module_with_transform) {
// Two float placeholders, both declared with extents {100, 20}.
109+ lang::Placeholder<float > A (" A" , {100 , 20 });
110+ lang::Placeholder<float > B (" B" , {100 , 20 });
111+
112+ lang::Buffer C_buf, D_buf;
113+
114+ // An inlined tensor: it must not appear in the final C code. It can be used any number of times, and its expression is expanded at each use.
115+ auto inlined0 = lang::Compute ({100 , 20 }, [&](Var i, Var j) { return A (i, j) * 2 .f + 1 .f ; });
116+
117+ auto C = lang::Compute (
118+ {100 , 20 }, [&](Var i, Var j) { return A (i, j) + B (i, j) + inlined0 (i, j); }, " C" );
119+ C->Bind (C_buf);
120+
121+ auto D = lang::Compute (
122+ {100 , 20 }, [&](Var i, Var j) { return C (i, j) * 2 .f * inlined0 (i, j); }, " D" );
123+ D->Bind (D_buf);
124+
// Split iterator 0 of C's stage by 4. The expected text below shows the
// resulting i_outer (0..24) / i_inner (0..3) loops. The returned iterators are
// captured here but not used again in this test.
125+ poly::Iterator i_outer, i_inner;
126+ std::tie (i_outer, i_inner) = C->stage ()->Split (poly::DefaultIterator (0 ), 4 );
127+
// Tile iterators 0 and 1 of D's stage with sizes 4 and 16. The second extent
// (20) is not a multiple of 16, which is why the expected j_inner loop bound
// is a min(...) expression.
128+ D->stage ()->Tile (poly::DefaultIterator (0 ), poly::DefaultIterator (1 ), 4 , 16 );
129+
130+ Target target;
131+ target.arch = Target::Arch ::X86;
132+ target.bits = Target::Bit ::k32;
133+ target.os = Target::OS ::Linux;
134+ lang::Module module (" module1" , target);
135+
136+ auto funcs = lang::Lower (" add1" , {A, B, C, D});
137+
// Fatal assertion: funcs.front() below is only reached when exactly one
// function was produced.
138+ ASSERT_EQ (funcs.size (), 1UL );
139+
// Only C_buf is appended to the module (D_buf is not). In the expected text C
// appears as a module-level buffer, while D is malloc'ed inside add1.
140+ module .Append (funcs.front ());
141+ module .Append (C_buf);
142+
143+ std::stringstream ss;
144+ CodeGenC codegen (ss, target);
145+ codegen.Compile (module );
146+
// Dump the generated source to stdout so it can be inspected in the test log.
147+ auto out = ss.str ();
148+ std::cout << " codegen C:" << std::endl << out << std::endl;
149+
// Expected generated source. The text inside the raw string literal is
// compared against the generator output, so it must not be reformatted.
150+ auto tgt = R"ROC(
151+ #ifndef _MODULE1_CINN_H_
152+ #define _MODULE1_CINN_H_
153+
154+ #include <cinn_runtime.h>
155+ #include <stdio.h>
156+
157+ cinn_buffer_t* C = cinn_buffer_t::new_(0/*target*/);
158+ void add1(const struct cinn_buffer_t *A, const struct cinn_buffer_t *B, const struct cinn_buffer_t *C, struct cinn_buffer_t *D)
159+ {
160+ cinn_buffer_malloc(D);
161+ for (int32_t i_outer = 0; (i_outer <= 24); i_outer += 1){
162+ for (int32_t i_inner = 0; (i_inner <= 3); i_inner += 1){
163+ for (int32_t j = 0; (j <= 19); j += 1){
164+ C[((((4 * i_outer) + i_inner) * 20) + j)] = ((A[((((4 * i_outer) + i_inner) * 20) + j)] + B[((((4 * i_outer) + i_inner) * 20) + j)]) + ((A[((((4 * i_outer) + i_inner) * 20) + j)] * 2) + 1));
165+ };
166+ };
167+ };
168+ for (int32_t i_outer = 0; (i_outer <= 24); i_outer += 1){
169+ for (int32_t i_inner = 0; (i_inner <= 3); i_inner += 1){
170+ for (int32_t j_outer = 0; (j_outer <= 1); j_outer += 1){
171+ for (int32_t j_inner = 0; (j_inner <= min(15, ((-16 * j_outer) + 19))); j_inner += 1){
172+ D[((((4 * i_outer) + i_inner) * 20) + ((16 * j_outer) + j_inner))] = ((C[((((4 * i_outer) + i_inner) * 20) + ((16 * j_outer) + j_inner))] * 2) * ((A[((((4 * i_outer) + i_inner) * 20) + ((16 * j_outer) + j_inner))] * 2) + 1));
173+ };
174+ };
175+ };
176+ };
177+ }
178+
179+ #endif // _MODULE1_CINN_H_
180+ )ROC" ;
181+
// Both the generated and the expected text are passed through utils::Trim
// before being compared.
182+ ASSERT_EQ (utils::Trim (out), utils::Trim (tgt));
183+ }
184+
108185} // namespace backends
109186} // namespace cinn
0 commit comments