Skip to content

Add support for array values in WHERE clause of CQL queries #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions spec/core/query_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,34 @@ describe CQL::Query do

select_query.should eq({output.strip, [1000]})
end

it "handles array values in WHERE clause" do
select_query = Northwind.query
.from(:customers)
.select(:name, :city)
.where(id: [1, 2, 3])
.to_sql

output = <<-SQL
SELECT customers.name, customers.city FROM customers WHERE customers.id IN (?, ?, ?)
SQL

select_query.should eq({output.strip, [1, 2, 3]})
end

it "handles multiple array conditions in WHERE clause" do
select_query = Northwind.query
.from(:customers)
.select(:name, :city)
.where(id: [1, 2, 3], city: ["Rome", "Vienna"])
.to_sql

output = <<-SQL
SELECT customers.name, customers.city FROM customers WHERE (customers.id IN (?, ?, ?)) AND (customers.city IN (?, ?))
SQL

select_query.should eq({output.strip, [1, 2, 3, "Rome", "Vienna"]})
end
end

describe "ORDER BY, LIMIT, and other clauses" do
Expand Down
16 changes: 10 additions & 6 deletions src/column.cr
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
require "./base_column"
require "db"
require "./table"

module CQL
# A column in a table
# This class represents a column in a table
Expand All @@ -19,7 +23,7 @@ module CQL
# :nodoc:
property name : Symbol
# :nodoc:
property type : Any
property type : T.class | Array(T.class)
# :nodoc:
getter? null : Bool = false
# :nodoc:
Expand Down Expand Up @@ -55,7 +59,7 @@ module CQL
# ```
def initialize(
@name : Symbol,
@type : T.class,
@type : T.class | Array(T.class),
@as_name : String? = nil,
@null : Bool = false,
@default : DB::Any = nil,
Expand Down Expand Up @@ -88,10 +92,10 @@ module CQL
# column = CQL::Column.new(:name, String)
# column.validate!("John")
# ```
def validate!(value)
return if value.class == JSON::Any && value.is_a?(String)
return if value.class == type
raise Error.new "Expected column `#{name}` to be #{type}, but got #{value.class}"
def validate!(value : T | Array(T)) forall T
return if value.class == JSON::Any
return if value.class == T
Copy link
Preview

Copilot AI May 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The updated validate! method only checks the class of the input value and does not iterate over array elements when an Array is provided, which may cause valid array inputs to be rejected. Consider iterating through the array and validating each element against the expected type.

Suggested change
return if value.class == T
if value.is_a?(Array)
value.each do |element|
unless element.is_a?(T)
raise Error.new "Expected all elements of column `#{name}` to be #{type}, but got #{element.class}"
end
end
return
end
return if value.is_a?(T)

Copilot uses AI. Check for mistakes.

raise Error.new "Expected column `#{name}` to be #{type} or Array(#{type}), but got #{value.class}"
end
end
end
2 changes: 1 addition & 1 deletion src/expression/expressions.cr
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ module Expression
getter column : Column
getter values : Array(DB::Any)

def initialize(@column : Column, values : Array(T)) forall T
def initialize(@column : Column, values)
@values = values.map { |v| v.as(DB::Any) }
end

Expand Down
51 changes: 44 additions & 7 deletions src/query.cr
Original file line number Diff line number Diff line change
Expand Up @@ -437,10 +437,9 @@ module CQL
end

# Accept Hash(Symbol, DB::Any) for backward compatibility
def where(hash : Hash(String | Symbol, DB::Any))
def where(hash : Hash(String | Symbol, DB::Any | Array(DB::Any)))
# Convert Symbol keys to String
string_keyed_hash = hash.transform_keys(&.to_s)
new_condition = build_condition_from_hash(string_keyed_hash)
new_condition = build_condition_from_hash(hash)
merge_where_condition(new_condition)
self
end
Expand All @@ -458,7 +457,7 @@ module CQL
end

def where(**fields)
new_condition = build_condition_from_hash(fields.to_h)
new_condition = build_condition_from_hash(fields)
merge_where_condition(new_condition)
self
end
Expand Down Expand Up @@ -726,12 +725,24 @@ module CQL
end

# Expects hash with String keys now
private def build_condition_from_hash(hash : Hash(Symbol | String, DB::Any))
private def build_condition_from_hash(hash : Hash(Symbol | String, T | Array(T))) forall T
Copy link
Preview

Copilot AI May 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] There are two overloads of build_condition_from_hash with similar logic; consolidating them could reduce duplication and improve code clarity.

Copilot uses AI. Check for mistakes.

condition = nil
hash.each_with_index do |(k, v), index|
# k is String (qualified or unqualified column name)
expr = get_expression(k, v.as(T)) # Handles String key
condition = index == 0 ? expr : Expression::And.new(condition.not_nil!, expr)
end
condition.not_nil!
end

private def build_condition_from_hash(fields)
condition = nil
index = 0
fields.each do |k, v|
# k is String (qualified or unqualified column name)
expr = get_expression(k, v) # Handles String key
condition = index == 0 ? expr : Expression::And.new(condition.not_nil!, expr)
index += 1
end
condition.not_nil!
end
Expand All @@ -746,14 +757,40 @@ module CQL
end

# Handles String field (qualified or unqualified)
private def get_expression(field : Symbol | String, value)
private def get_expression(field : Symbol | String, value : T | Array(T)) forall T
# find_column handles String field, finds BaseColumn
column = find_column(field)
# find_alias_for_table returns String alias
col_alias_str = find_alias_for_table(column.table.not_nil!)
column.validate!(value.as(T))

# Create Expression::Column with String alias
col_expr = Expression::Column.new(column, alias_name: col_alias_str)

# Handle array values for IN conditions
if value.is_a?(Array)
Expression::InCondition.new(col_expr, value)
else
Expression::Compare.new(col_expr, "=", value.as(DB::Any))
end
end

private def get_expression(field : Symbol | String, value : String | Array(String))
# find_column handles String field, finds BaseColumn
column = find_column(field)
# find_alias_for_table returns String alias
col_alias_str = find_alias_for_table(column.table.not_nil!)
column.validate!(value)

# Create Expression::Column with String alias
Expression::Compare.new(Expression::Column.new(column, alias_name: col_alias_str), "=", value)
col_expr = Expression::Column.new(column, alias_name: col_alias_str)

# Handle array values for IN conditions
if value.is_a?(Array)
Expression::InCondition.new(col_expr, value)
else
Expression::Compare.new(col_expr, "=", value.as(DB::Any))
end
end

private def build_group_by
Expand Down
Loading