Skip to content

Commit a01977d

Browse files
authored
Merge pull request #456 from Shopify/pb-use-arel-for-conditions
Use Arel instead of String for AR Enumerator conditionals
2 parents cccb61d + f2392df commit a01977d

File tree

5 files changed

+59
-19
lines changed

5 files changed

+59
-19
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
when generating position for cursor based on `:id` column (Rails 7.1 and above, where composite
1515
primary models are now supported). This ensures we grab the value of the id column, rather than a
1616
potentially composite primary key value.
17+
- [456](https://github.com/Shopify/job-iteration/pull/431) - Use Arel to generate SQL that's type compatible for the
18+
cursor pagination conditionals in ActiveRecord cursor. Previously, the cursor would coerce numeric ids to a string value
19+
(e.g.: `... AND id > '1'`)
1720

1821
## v1.4.1 (Sep 5, 2023)
1922

lib/job-iteration/active_record_cursor.rb

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,8 @@ def initialize
1818
end
1919
end
2020

21-
def initialize(relation, columns = nil, position = nil)
22-
@columns = if columns
23-
Array(columns)
24-
else
25-
Array(relation.primary_key).map { |pk| "#{relation.table_name}.#{pk}" }
26-
end
21+
def initialize(relation, columns, position = nil)
22+
@columns = columns
2723
self.position = Array.wrap(position)
2824
raise ArgumentError, "Must specify at least one column" if columns.empty?
2925
if relation.joins_values.present? && !@columns.all? { |column| column.to_s.include?(".") }
@@ -34,7 +30,7 @@ def initialize(relation, columns = nil, position = nil)
3430
raise ConditionNotSupportedError
3531
end
3632

37-
@base_relation = relation.reorder(@columns.join(","))
33+
@base_relation = relation.reorder(*@columns)
3834
@reached_end = false
3935
end
4036

@@ -54,12 +50,10 @@ def position=(position)
5450

5551
def update_from_record(record)
5652
self.position = @columns.map do |column|
57-
method = column.to_s.split(".").last
58-
59-
if ActiveRecord.version >= Gem::Version.new("7.1.0.alpha") && method == "id"
53+
if ActiveRecord.version >= Gem::Version.new("7.1.0.alpha") && column.name == "id"
6054
record.id_value
6155
else
62-
record.send(method.to_sym)
56+
record.send(column.name)
6357
end
6458
end
6559
end
@@ -89,14 +83,14 @@ def conditions
8983
i = @position.size - 1
9084
column = @columns[i]
9185
conditions = if @columns.size == @position.size
92-
"#{column} > ?"
86+
column.gt(@position[i])
9387
else
94-
"#{column} >= ?"
88+
column.gteq(@position[i])
9589
end
9690
while i > 0
9791
i -= 1
9892
column = @columns[i]
99-
conditions = "#{column} > ? OR (#{column} = ? AND (#{conditions}))"
93+
conditions = column.gt(@position[i]).or(column.eq(@position[i]).and(conditions))
10094
end
10195
ret = @position.reduce([conditions]) { |params, value| params << value << value }
10296
ret.pop

lib/job-iteration/active_record_enumerator.rb

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ def initialize(relation, columns: nil, batch_size: 100, cursor: nil)
1111
@relation = relation
1212
@batch_size = batch_size
1313
@columns = if columns
14-
Array(columns)
14+
Array(columns).map { |col| relation.arel_table[col.to_sym] }
1515
else
16-
Array(relation.primary_key).map { |pk| "#{relation.table_name}.#{pk}" }
16+
Array(relation.primary_key).map { |pk| relation.arel_table[pk.to_sym] }
1717
end
1818
@cursor = cursor
1919
end
@@ -45,7 +45,7 @@ def size
4545

4646
def cursor_value(record)
4747
positions = @columns.map do |column|
48-
attribute_name = column.to_s.split(".").last
48+
attribute_name = column.name.to_sym
4949
column_value(record, attribute_name)
5050
end
5151
return positions.first if positions.size == 1
@@ -58,8 +58,8 @@ def finder_cursor
5858
end
5959

6060
def column_value(record, attribute)
61-
value = record.read_attribute(attribute.to_sym)
62-
case record.class.columns_hash.fetch(attribute).type
61+
value = record.read_attribute(attribute)
62+
case record.class.columns_hash.fetch(attribute.to_s).type
6363
when :datetime
6464
value.strftime(SQL_DATETIME_WITH_NSEC)
6565
else

test/test_helper.rb

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,40 @@ def assert_logged(message)
106106
end
107107
end
108108

109+
module ActiveRecordHelpers
110+
def assert_sql(*patterns_to_match, &block)
111+
captured_queries = []
112+
assert_nothing_raised do
113+
ActiveSupport::Notifications.subscribed(
114+
->(_name, _start_time, _end_time, _subscriber_id, payload) { captured_queries << payload[:sql] },
115+
"sql.active_record",
116+
&block
117+
)
118+
end
119+
120+
failed_patterns = []
121+
patterns_to_match.each do |pattern|
122+
failed_check = captured_queries.none? do |sql|
123+
case pattern
124+
when Regexp
125+
sql.match?(pattern)
126+
when String
127+
sql == pattern
128+
else
129+
raise ArgumentError, "#assert_sql encountered an unknown matcher #{pattern.inspect}"
130+
end
131+
end
132+
failed_patterns << pattern if failed_check
133+
end
134+
queries = captured_queries.empty? ? "" : "\nQueries:\n #{captured_queries.join("\n ")}"
135+
assert_predicate(
136+
failed_patterns,
137+
:empty?,
138+
"Query pattern(s) #{failed_patterns.map(&:inspect).join(", ")} not found.#{queries}",
139+
)
140+
end
141+
end
142+
109143
JobIteration.logger = Logger.new(IO::NULL)
110144
ActiveJob::Base.logger = Logger.new(IO::NULL)
111145

test/unit/active_record_enumerator_test.rb

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
module JobIteration
66
class ActiveRecordEnumeratorTest < IterationUnitTest
7+
include ActiveRecordHelpers
8+
79
SQL_TIME_FORMAT = "%Y-%m-%d %H:%M:%S.%N"
810
test "#records yields every record with their cursor position" do
911
enum = build_enumerator.records
@@ -133,6 +135,13 @@ class ActiveRecordEnumeratorTest < IterationUnitTest
133135
end
134136
end
135137

138+
test "enumerator paginates using integer conditionals for primary key when no columns are defined" do
139+
enum = build_enumerator(relation: Product.all, batch_size: 1).records
140+
assert_sql(/`products`\.`id` > 1/) do
141+
enum.take(2)
142+
end
143+
end
144+
136145
private
137146

138147
def build_enumerator(relation: Product.all, batch_size: 2, columns: nil, cursor: nil)

0 commit comments

Comments
 (0)