@@ -22,22 +22,25 @@ use async_trait::async_trait;
2222use datafusion_catalog:: Session ;
2323use datafusion_catalog:: TableFunctionImpl ;
2424use datafusion_catalog:: TableProvider ;
25- use datafusion_common:: { not_impl_err , plan_err, Result , ScalarValue } ;
25+ use datafusion_common:: { plan_err, Result , ScalarValue } ;
2626use datafusion_expr:: { Expr , TableType } ;
2727use datafusion_physical_plan:: memory:: { LazyBatchGenerator , LazyMemoryExec } ;
2828use datafusion_physical_plan:: ExecutionPlan ;
2929use parking_lot:: RwLock ;
3030use std:: fmt;
3131use std:: sync:: Arc ;
3232
/// Arguments of a `generate_series` call after its literal arguments have been
/// inspected.
#[derive(Debug, Clone)]
enum GenSeriesArgs {
    /// At least one argument was a `NULL` literal; the table yields zero rows.
    ContainsNull,
    /// Every argument was a non-null integer literal.
    AllNotNullArgs {
        /// First value of the series (inclusive).
        start: i64,
        /// Bound of the series (inclusive).
        end: i64,
        /// Increment applied between consecutive values.
        step: i64,
    },
}
38+
3339/// Table that generates a series of integers from `start`(inclusive) to `end`(inclusive)
3440#[ derive( Debug , Clone ) ]
3541struct GenerateSeriesTable {
3642 schema : SchemaRef ,
37- // None if input is Null
38- start : Option < i64 > ,
39- // None if input is Null
40- end : Option < i64 > ,
43+ args : GenSeriesArgs ,
4144}
4245
4346/// Table state that generates a series of integers from `start`(inclusive) to `end`(inclusive)
@@ -46,10 +49,21 @@ struct GenerateSeriesState {
4649 schema : SchemaRef ,
4750 start : i64 , // Kept for display
4851 end : i64 ,
52+ step : i64 ,
4953 batch_size : usize ,
5054
5155 /// Tracks current position when generating table
52- current : i64 ,
56+ next : i64 ,
57+ }
58+
59+ impl GenerateSeriesState {
60+ fn reach_end ( & self , val : i64 ) -> bool {
61+ if self . step > 0 {
62+ return val > self . end ;
63+ }
64+
65+ return val < self . end ;
66+ }
5367}
5468
5569/// Detail to display for 'Explain' plan
@@ -65,19 +79,19 @@ impl fmt::Display for GenerateSeriesState {
6579
6680impl LazyBatchGenerator for GenerateSeriesState {
6781 fn generate_next_batch ( & mut self ) -> Result < Option < RecordBatch > > {
68- // Check if we've reached the end
69- if self . current > self . end {
82+ let mut buf = Vec :: with_capacity ( self . batch_size ) ;
83+ while buf. len ( ) < self . batch_size && !self . reach_end ( self . next ) {
84+ buf. push ( self . next ) ;
85+ self . next += self . step ;
86+ }
87+ let array = Int64Array :: from ( buf) ;
88+
89+ if array. len ( ) == 0 {
7090 return Ok ( None ) ;
7191 }
7292
73- // Construct batch
74- let batch_end = ( self . current + self . batch_size as i64 - 1 ) . min ( self . end ) ;
75- let array = Int64Array :: from_iter_values ( self . current ..=batch_end) ;
7693 let batch = RecordBatch :: try_new ( self . schema . clone ( ) , vec ! [ Arc :: new( array) ] ) ?;
7794
78- // Update current position for next batch
79- self . current = batch_end + 1 ;
80-
8195 Ok ( Some ( batch) )
8296 }
8397}
@@ -104,77 +118,90 @@ impl TableProvider for GenerateSeriesTable {
104118 _limit : Option < usize > ,
105119 ) -> Result < Arc < dyn ExecutionPlan > > {
106120 let batch_size = state. config_options ( ) . execution . batch_size ;
107- match ( self . start , self . end ) {
108- ( Some ( start) , Some ( end) ) => {
109- if start > end {
110- return plan_err ! (
111- "End value must be greater than or equal to start value"
112- ) ;
113- }
114-
115- Ok ( Arc :: new ( LazyMemoryExec :: try_new (
116- self . schema . clone ( ) ,
117- vec ! [ Arc :: new( RwLock :: new( GenerateSeriesState {
118- schema: self . schema. clone( ) ,
119- start,
120- end,
121- current: start,
122- batch_size,
123- } ) ) ] ,
124- ) ?) )
125- }
126- _ => {
127- // Either start or end is None, return a generator that outputs 0 rows
128- Ok ( Arc :: new ( LazyMemoryExec :: try_new (
129- self . schema . clone ( ) ,
130- vec ! [ Arc :: new( RwLock :: new( GenerateSeriesState {
131- schema: self . schema. clone( ) ,
132- start: 0 ,
133- end: 0 ,
134- current: 1 ,
135- batch_size,
136- } ) ) ] ,
137- ) ?) )
138- }
139- }
121+
122+ let state = match self . args {
123+ // if args have null, then return 0 row
124+ GenSeriesArgs :: ContainsNull => GenerateSeriesState {
125+ schema : self . schema . clone ( ) ,
126+ start : 0 ,
127+ end : 0 ,
128+ step : 1 ,
129+ next : 1 ,
130+ batch_size,
131+ } ,
132+ GenSeriesArgs :: AllNotNullArgs { start, end, step } => GenerateSeriesState {
133+ schema : self . schema . clone ( ) ,
134+ start,
135+ end,
136+ step,
137+ next : start,
138+ batch_size,
139+ } ,
140+ } ;
141+
142+ Ok ( Arc :: new ( LazyMemoryExec :: try_new (
143+ self . schema . clone ( ) ,
144+ vec ! [ Arc :: new( RwLock :: new( state) ) ] ,
145+ ) ?) )
140146 }
141147}
142148
/// Table function implementing `generate_series`; it carries no state of its
/// own.
#[derive(Debug)]
pub struct GenerateSeriesFunc {}
145151
146152impl TableFunctionImpl for GenerateSeriesFunc {
147- // Check input `exprs` type and number. Input validity check (e.g. start <= end)
148- // will be performed in `TableProvider::scan`
149153 fn call ( & self , exprs : & [ Expr ] ) -> Result < Arc < dyn TableProvider > > {
150- // TODO: support 1 or 3 arguments following DuckDB:
151- // <https://duckdb.org/docs/sql/functions/list#generate_series>
152- if exprs. len ( ) == 3 || exprs. len ( ) == 1 {
153- return not_impl_err ! ( "generate_series does not support 1 or 3 arguments" ) ;
154+ if exprs. len ( ) == 0 || exprs. len ( ) > 3 {
155+ return plan_err ! ( "generate_series function requires 1 to 3 arguments" ) ;
154156 }
155157
156- if exprs. len ( ) != 2 {
157- return plan_err ! ( "generate_series expects 2 arguments" ) ;
158+ let mut normalize_args = Vec :: new ( ) ;
159+ for expr in exprs {
160+ match expr {
161+ Expr :: Literal ( ScalarValue :: Null ) => { }
162+ Expr :: Literal ( ScalarValue :: Int64 ( Some ( n) ) ) => normalize_args. push ( * n) ,
163+ _ => return plan_err ! ( "First argument must be an integer literal" ) ,
164+ } ;
158165 }
159166
160- let start = match & exprs[ 0 ] {
161- Expr :: Literal ( ScalarValue :: Null ) => None ,
162- Expr :: Literal ( ScalarValue :: Int64 ( Some ( n) ) ) => Some ( * n) ,
163- _ => return plan_err ! ( "First argument must be an integer literal" ) ,
164- } ;
165-
166- let end = match & exprs[ 1 ] {
167- Expr :: Literal ( ScalarValue :: Null ) => None ,
168- Expr :: Literal ( ScalarValue :: Int64 ( Some ( n) ) ) => Some ( * n) ,
169- _ => return plan_err ! ( "Second argument must be an integer literal" ) ,
170- } ;
171-
172167 let schema = Arc :: new ( Schema :: new ( vec ! [ Field :: new(
173168 "value" ,
174169 DataType :: Int64 ,
175170 false ,
176171 ) ] ) ) ;
177172
178- Ok ( Arc :: new ( GenerateSeriesTable { schema, start, end } ) )
173+ if normalize_args. len ( ) != exprs. len ( ) {
174+ // contain null
175+ return Ok ( Arc :: new ( GenerateSeriesTable {
176+ schema,
177+ args : GenSeriesArgs :: ContainsNull ,
178+ } ) ) ;
179+ }
180+
181+ let ( start, end, step) = match & normalize_args[ ..] {
182+ [ end] => ( 0 , * end, 1 ) ,
183+ [ start, end] => ( * start, * end, 1 ) ,
184+ [ start, end, step] => ( * start, * end, * step) ,
185+ _ => {
186+ return plan_err ! ( "generate_series function requires 1 to 3 arguments" ) ;
187+ }
188+ } ;
189+
190+ if start > end && step > 0 {
191+ return plan_err ! ( "start is bigger than end, but increment is positive: cannot generate infinite series" ) ;
192+ }
193+
194+ if start < end && step < 0 {
195+ return plan_err ! ( "start is smaller than end, but increment is negative: cannot generate infinite series" ) ;
196+ }
197+
198+ if step == 0 {
199+ return plan_err ! ( "step cannot be zero" ) ;
200+ }
201+
202+ Ok ( Arc :: new ( GenerateSeriesTable {
203+ schema,
204+ args : GenSeriesArgs :: AllNotNullArgs { start, end, step } ,
205+ } ) )
179206 }
180207}
0 commit comments