@@ -22,36 +22,53 @@ use async_trait::async_trait;
2222use datafusion_catalog:: Session ;
2323use datafusion_catalog:: TableFunctionImpl ;
2424use datafusion_catalog:: TableProvider ;
25- use datafusion_common:: { not_impl_err , plan_err, Result , ScalarValue } ;
25+ use datafusion_common:: { plan_err, Result , ScalarValue } ;
2626use datafusion_expr:: { Expr , TableType } ;
2727use datafusion_physical_plan:: memory:: { LazyBatchGenerator , LazyMemoryExec } ;
2828use datafusion_physical_plan:: ExecutionPlan ;
2929use parking_lot:: RwLock ;
3030use std:: fmt;
3131use std:: sync:: Arc ;
3232
33- /// Table that generates a series of integers from `start`(inclusive) to `end`(inclusive)
/// Arguments of a `generate_series` invocation after literal evaluation.
#[derive(Debug, Clone)]
enum GenSeriesArgs {
    /// At least one of the arguments (start, end, step) was NULL, so no
    /// series is generated.
    ContainsNull,
    /// Every argument was non-null; carries the concrete `start`, `end`
    /// and `step` values used to generate the series.
    AllNotNullArgs { start: i64, end: i64, step: i64 },
}
41+
42+ /// Table that generates a series of integers from `start`(inclusive) to `end`(inclusive), incrementing by step
3443#[ derive( Debug , Clone ) ]
3544struct GenerateSeriesTable {
3645 schema : SchemaRef ,
37- // None if input is Null
38- start : Option < i64 > ,
39- // None if input is Null
40- end : Option < i64 > ,
46+ args : GenSeriesArgs ,
4147}
4248
43- /// Table state that generates a series of integers from `start`(inclusive) to `end`(inclusive)
49+ /// Table state that generates a series of integers from `start`(inclusive) to `end`(inclusive), incrementing by step
4450#[ derive( Debug , Clone ) ]
4551struct GenerateSeriesState {
4652 schema : SchemaRef ,
4753 start : i64 , // Kept for display
4854 end : i64 ,
55+ step : i64 ,
4956 batch_size : usize ,
5057
5158 /// Tracks current position when generating table
5259 current : i64 ,
5360}
5461
62+ impl GenerateSeriesState {
63+ fn reach_end ( & self , val : i64 ) -> bool {
64+ if self . step > 0 {
65+ return val > self . end ;
66+ }
67+
68+ val < self . end
69+ }
70+ }
71+
5572/// Detail to display for 'Explain' plan
5673impl fmt:: Display for GenerateSeriesState {
5774 fn fmt ( & self , f : & mut fmt:: Formatter ) -> fmt:: Result {
@@ -65,19 +82,19 @@ impl fmt::Display for GenerateSeriesState {
6582
6683impl LazyBatchGenerator for GenerateSeriesState {
6784 fn generate_next_batch ( & mut self ) -> Result < Option < RecordBatch > > {
68- // Check if we've reached the end
69- if self . current > self . end {
85+ let mut buf = Vec :: with_capacity ( self . batch_size ) ;
86+ while buf. len ( ) < self . batch_size && !self . reach_end ( self . current ) {
87+ buf. push ( self . current ) ;
88+ self . current += self . step ;
89+ }
90+ let array = Int64Array :: from ( buf) ;
91+
92+ if array. is_empty ( ) {
7093 return Ok ( None ) ;
7194 }
7295
73- // Construct batch
74- let batch_end = ( self . current + self . batch_size as i64 - 1 ) . min ( self . end ) ;
75- let array = Int64Array :: from_iter_values ( self . current ..=batch_end) ;
7696 let batch = RecordBatch :: try_new ( self . schema . clone ( ) , vec ! [ Arc :: new( array) ] ) ?;
7797
78- // Update current position for next batch
79- self . current = batch_end + 1 ;
80-
8198 Ok ( Some ( batch) )
8299 }
83100}
@@ -104,77 +121,90 @@ impl TableProvider for GenerateSeriesTable {
104121 _limit : Option < usize > ,
105122 ) -> Result < Arc < dyn ExecutionPlan > > {
106123 let batch_size = state. config_options ( ) . execution . batch_size ;
107- match ( self . start , self . end ) {
108- ( Some ( start) , Some ( end) ) => {
109- if start > end {
110- return plan_err ! (
111- "End value must be greater than or equal to start value"
112- ) ;
113- }
114-
115- Ok ( Arc :: new ( LazyMemoryExec :: try_new (
116- self . schema . clone ( ) ,
117- vec ! [ Arc :: new( RwLock :: new( GenerateSeriesState {
118- schema: self . schema. clone( ) ,
119- start,
120- end,
121- current: start,
122- batch_size,
123- } ) ) ] ,
124- ) ?) )
125- }
126- _ => {
127- // Either start or end is None, return a generator that outputs 0 rows
128- Ok ( Arc :: new ( LazyMemoryExec :: try_new (
129- self . schema . clone ( ) ,
130- vec ! [ Arc :: new( RwLock :: new( GenerateSeriesState {
131- schema: self . schema. clone( ) ,
132- start: 0 ,
133- end: 0 ,
134- current: 1 ,
135- batch_size,
136- } ) ) ] ,
137- ) ?) )
138- }
139- }
124+
125+ let state = match self . args {
126+ // if args have null, then return 0 row
127+ GenSeriesArgs :: ContainsNull => GenerateSeriesState {
128+ schema : self . schema . clone ( ) ,
129+ start : 0 ,
130+ end : 0 ,
131+ step : 1 ,
132+ current : 1 ,
133+ batch_size,
134+ } ,
135+ GenSeriesArgs :: AllNotNullArgs { start, end, step } => GenerateSeriesState {
136+ schema : self . schema . clone ( ) ,
137+ start,
138+ end,
139+ step,
140+ current : start,
141+ batch_size,
142+ } ,
143+ } ;
144+
145+ Ok ( Arc :: new ( LazyMemoryExec :: try_new (
146+ self . schema . clone ( ) ,
147+ vec ! [ Arc :: new( RwLock :: new( state) ) ] ,
148+ ) ?) )
140149 }
141150}
142151
/// Table function whose `call` turns one to three integer-literal arguments
/// into a `GenerateSeriesTable`.
#[derive(Debug)]
pub struct GenerateSeriesFunc {}
145154
146155impl TableFunctionImpl for GenerateSeriesFunc {
147- // Check input `exprs` type and number. Input validity check (e.g. start <= end)
148- // will be performed in `TableProvider::scan`
149156 fn call ( & self , exprs : & [ Expr ] ) -> Result < Arc < dyn TableProvider > > {
150- // TODO: support 1 or 3 arguments following DuckDB:
151- // <https://duckdb.org/docs/sql/functions/list#generate_series>
152- if exprs. len ( ) == 3 || exprs. len ( ) == 1 {
153- return not_impl_err ! ( "generate_series does not support 1 or 3 arguments" ) ;
157+ if exprs. is_empty ( ) || exprs. len ( ) > 3 {
158+ return plan_err ! ( "generate_series function requires 1 to 3 arguments" ) ;
154159 }
155160
156- if exprs. len ( ) != 2 {
157- return plan_err ! ( "generate_series expects 2 arguments" ) ;
161+ let mut normalize_args = Vec :: new ( ) ;
162+ for expr in exprs {
163+ match expr {
164+ Expr :: Literal ( ScalarValue :: Null ) => { }
165+ Expr :: Literal ( ScalarValue :: Int64 ( Some ( n) ) ) => normalize_args. push ( * n) ,
166+ _ => return plan_err ! ( "First argument must be an integer literal" ) ,
167+ } ;
158168 }
159169
160- let start = match & exprs[ 0 ] {
161- Expr :: Literal ( ScalarValue :: Null ) => None ,
162- Expr :: Literal ( ScalarValue :: Int64 ( Some ( n) ) ) => Some ( * n) ,
163- _ => return plan_err ! ( "First argument must be an integer literal" ) ,
164- } ;
165-
166- let end = match & exprs[ 1 ] {
167- Expr :: Literal ( ScalarValue :: Null ) => None ,
168- Expr :: Literal ( ScalarValue :: Int64 ( Some ( n) ) ) => Some ( * n) ,
169- _ => return plan_err ! ( "Second argument must be an integer literal" ) ,
170- } ;
171-
172170 let schema = Arc :: new ( Schema :: new ( vec ! [ Field :: new(
173171 "value" ,
174172 DataType :: Int64 ,
175173 false ,
176174 ) ] ) ) ;
177175
178- Ok ( Arc :: new ( GenerateSeriesTable { schema, start, end } ) )
176+ if normalize_args. len ( ) != exprs. len ( ) {
177+ // contain null
178+ return Ok ( Arc :: new ( GenerateSeriesTable {
179+ schema,
180+ args : GenSeriesArgs :: ContainsNull ,
181+ } ) ) ;
182+ }
183+
184+ let ( start, end, step) = match & normalize_args[ ..] {
185+ [ end] => ( 0 , * end, 1 ) ,
186+ [ start, end] => ( * start, * end, 1 ) ,
187+ [ start, end, step] => ( * start, * end, * step) ,
188+ _ => {
189+ return plan_err ! ( "generate_series function requires 1 to 3 arguments" ) ;
190+ }
191+ } ;
192+
193+ if start > end && step > 0 {
194+ return plan_err ! ( "start is bigger than end, but increment is positive: cannot generate infinite series" ) ;
195+ }
196+
197+ if start < end && step < 0 {
198+ return plan_err ! ( "start is smaller than end, but increment is negative: cannot generate infinite series" ) ;
199+ }
200+
201+ if step == 0 {
202+ return plan_err ! ( "step cannot be zero" ) ;
203+ }
204+
205+ Ok ( Arc :: new ( GenerateSeriesTable {
206+ schema,
207+ args : GenSeriesArgs :: AllNotNullArgs { start, end, step } ,
208+ } ) )
179209 }
180210}
0 commit comments