Skip to content

Commit a047dca

Browse files
Creating jit sqlite writing from vars_to_write
1 parent 1483501 commit a047dca

File tree

1 file changed

+10
-18
lines changed

1 file changed

+10
-18
lines changed

parcels/compilation/codegenerator.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -952,15 +952,11 @@ def visit_FunctionDef(self, node):
952952
body += [c.Statement(f"type_coord particle_d{coord} = 0")]
953953
body += [stmt.ccode for stmt in node.body if not (hasattr(stmt, 'value') and type(stmt.value) is ast.Str)]
954954

955-
body += [c.Statement('sqlite3_prepare_v2(sql_db, "INSERT INTO particles VALUES (?, ?, ?, ?, ?, ?, ?, ?)", -1, &stmt, NULL)')]
956-
body += [c.Statement('sqlite3_bind_int(stmt, 1, particles->id[pnum])')]
957-
body += [c.Statement('sqlite3_bind_double(stmt, 2, particles->time[pnum])')]
958-
body += [c.Statement('sqlite3_bind_double(stmt, 3, particles->lon[pnum])')]
959-
body += [c.Statement('sqlite3_bind_double(stmt, 4, particles->lat[pnum])')]
960-
body += [c.Statement('sqlite3_bind_double(stmt, 5, particles->depth[pnum])')]
961-
body += [c.Statement('sqlite3_bind_double(stmt, 6, particles->u[pnum])')]
962-
body += [c.Statement('sqlite3_bind_double(stmt, 7, particles->v[pnum])')]
963-
body += [c.Statement('sqlite3_bind_double(stmt, 8, particles->p[pnum])')]
955+
dtype_map = {np.float32: 'double', np.float64: 'double', np.int32: 'int', np.int64: 'int'}
956+
query = ', '.join('?' * len(self.fieldset.particlefile.vars_to_write))
957+
body += [c.Statement(f'sqlite3_prepare_v2(sql_db, "INSERT INTO particles VALUES ({query})", -1, &stmt, NULL)')]
958+
for i, var in enumerate(self.fieldset.particlefile.vars_to_write.items()):
959+
body += [c.Statement(f"sqlite3_bind_{dtype_map[var[1]]}(stmt, {i+1}, particles->{var[0]}[pnum])")]
964960
body += [c.Statement('sqlite3_step(stmt)')]
965961
body += [c.Statement('sqlite3_finalize(stmt)')]
966962

@@ -1123,15 +1119,11 @@ def visit_FunctionDef(self, node):
11231119
body += [c.Statement(f"type_coord particle_d{coord} = 0")]
11241120
body += [stmt.ccode for stmt in node.body if not (hasattr(stmt, 'value') and type(stmt.value) is ast.Str)]
11251121

1126-
body += [c.Statement('sqlite3_prepare_v2(sql_db, "INSERT INTO particles VALUES (?, ?, ?, ?, ?, ?, ?, ?)", -1, &stmt, NULL)')]
1127-
body += [c.Statement('sqlite3_bind_int(stmt, 1, particle->id)')]
1128-
body += [c.Statement('sqlite3_bind_double(stmt, 2, particle->time)')]
1129-
body += [c.Statement('sqlite3_bind_double(stmt, 3, particle->lon)')]
1130-
body += [c.Statement('sqlite3_bind_double(stmt, 4, particle->lat)')]
1131-
body += [c.Statement('sqlite3_bind_double(stmt, 5, particle->depth)')]
1132-
body += [c.Statement('sqlite3_bind_double(stmt, 6, particle->u)')]
1133-
body += [c.Statement('sqlite3_bind_double(stmt, 7, particle->v)')]
1134-
body += [c.Statement('sqlite3_bind_double(stmt, 8, particle->p)')]
1122+
dtype_map = {np.float32: 'double', np.float64: 'double', np.int32: 'int', np.int64: 'int'}
1123+
query = ', '.join('?' * len(self.fieldset.particlefile.vars_to_write))
1124+
body += [c.Statement(f'sqlite3_prepare_v2(sql_db, "INSERT INTO particles VALUES ({query})", -1, &stmt, NULL)')]
1125+
for i, var in enumerate(self.fieldset.particlefile.vars_to_write.items()):
1126+
body += [c.Statement(f"sqlite3_bind_{dtype_map[var[1]]}(stmt, {i+1}, particle->{var[0]})")]
11351127
body += [c.Statement('sqlite3_step(stmt)')]
11361128
body += [c.Statement('sqlite3_finalize(stmt)')]
11371129

0 commit comments

Comments
 (0)