Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public class AINodeConcurrentForecastIT {
private static final Logger LOGGER = LoggerFactory.getLogger(AINodeConcurrentForecastIT.class);

private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE =
"SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time, output_length=>%d)";
"SELECT * FROM FORECAST(model_id=>'%s', targets=>(SELECT time,s FROM root.AI) ORDER BY time, output_length=>%d)";

@BeforeClass
public static void setUp() throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,21 @@
import java.sql.Statement;

import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP;
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest;

@RunWith(IoTDBTestRunner.class)
@Category({AIClusterIT.class})
public class AINodeForecastIT {

private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE =
"SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time, s%d FROM db.AI) ORDER BY time)";
"SELECT * FROM FORECAST("
+ "model_id=>'%s', "
+ "targets=>(SELECT time, s%d FROM db.AI WHERE time<%d ORDER BY time DESC LIMIT %d) ORDER BY time, "
+ "output_start_time=>%d, "
+ "output_length=>%d, "
+ "output_interval=>%d, "
+ "timecol=>'%s'"
+ ")";

@BeforeClass
public static void setUp() throws Exception {
Expand All @@ -55,7 +63,7 @@ public static void setUp() throws Exception {
statement.execute("CREATE DATABASE db");
statement.execute(
"CREATE TABLE db.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 FIELD, s3 INT64 FIELD)");
for (int i = 0; i < 2880; i++) {
for (int i = 0; i < 5760; i++) {
statement.execute(
String.format(
"INSERT INTO db.AI(time,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)",
Expand All @@ -81,18 +89,100 @@ public void forecastTableFunctionTest() throws SQLException {

public void forecastTableFunctionTest(
Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) throws SQLException {
// Invoke call inference for specified models, there should exist result.
// Invoke forecast table function for specified models, there should exist result.
for (int i = 0; i < 4; i++) {
String forecastTableFunctionSQL =
String.format(FORECAST_TABLE_FUNCTION_SQL_TEMPLATE, modelInfo.getModelId(), i);
String.format(
FORECAST_TABLE_FUNCTION_SQL_TEMPLATE,
modelInfo.getModelId(),
i,
5760,
2880,
5760,
96,
1,
"time");
try (ResultSet resultSet = statement.executeQuery(forecastTableFunctionSQL)) {
int count = 0;
while (resultSet.next()) {
count++;
}
// Ensure the call inference return results
// Ensure the forecast sentence return results
Assert.assertTrue(count > 0);
}
}
}

@Test
public void forecastTableFunctionErrorTest() throws SQLException {
for (AINodeTestUtils.FakeModelInfo modelInfo : BUILTIN_MODEL_MAP.values()) {
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
Statement statement = connection.createStatement()) {
forecastTableFunctionErrorTest(statement, modelInfo);
}
}
}

public void forecastTableFunctionErrorTest(
Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) throws SQLException {
// OUTPUT_START_TIME error
String invalidOutputStartTimeSQL =
String.format(
FORECAST_TABLE_FUNCTION_SQL_TEMPLATE,
modelInfo.getModelId(),
0,
5760,
2880,
5759,
96,
1,
"time");
errorTest(
statement,
invalidOutputStartTimeSQL,
"701: The OUTPUT_START_TIME should be greater than the maximum timestamp of target time series. Expected greater than [5759] but found [5759].");

// OUTPUT_LENGTH error
String invalidOutputLengthSQL =
String.format(
FORECAST_TABLE_FUNCTION_SQL_TEMPLATE,
modelInfo.getModelId(),
0,
5760,
2880,
5760,
0,
1,
"time");
errorTest(statement, invalidOutputLengthSQL, "701: OUTPUT_LENGTH should be greater than 0");

// OUTPUT_INTERVAL error
String invalidOutputIntervalSQL =
String.format(
FORECAST_TABLE_FUNCTION_SQL_TEMPLATE,
modelInfo.getModelId(),
0,
5760,
2880,
5760,
96,
-1,
"time");
errorTest(statement, invalidOutputIntervalSQL, "701: OUTPUT_INTERVAL should be greater than 0");

// TIMECOL error
String invalidTimecolSQL2 =
String.format(
FORECAST_TABLE_FUNCTION_SQL_TEMPLATE,
modelInfo.getModelId(),
0,
5760,
2880,
5760,
96,
1,
"s0");
errorTest(
statement, invalidTimecolSQL2, "701: The type of the column [s0] is not as expected.");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,11 @@ def prepare_inputs_for_generation(
if attention_mask is not None and attention_mask.shape[1] > (
input_ids.shape[1] // self.config.input_token_len
):
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
input_ids = input_ids[
:,
-(attention_mask.shape[1] - past_length)
* self.config.input_token_len :,
]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < (input_ids.shape[1] // self.config.input_token_len):
Expand All @@ -623,9 +627,10 @@ def prepare_inputs_for_generation(
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[
:, -(input_ids.shape[1] // self.config.input_token_len) :
]
token_num = (
input_ids.shape[1] + self.config.input_token_len - 1
) // self.config.input_token_len
position_ids = position_ids[:, -token_num:]

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,11 @@ def prepare_inputs_for_generation(
if attention_mask is not None and attention_mask.shape[1] > (
input_ids.shape[1] // self.config.input_token_len
):
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
input_ids = input_ids[
:,
-(attention_mask.shape[1] - past_length)
* self.config.input_token_len :,
]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < (input_ids.shape[1] // self.config.input_token_len):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.iotdb.commons.udf.builtin.relational.tvf.SessionTableFunction;
import org.apache.iotdb.commons.udf.builtin.relational.tvf.TumbleTableFunction;
import org.apache.iotdb.commons.udf.builtin.relational.tvf.VariationTableFunction;
import org.apache.iotdb.db.queryengine.plan.relational.function.tvf.ClassifyTableFunction;
import org.apache.iotdb.db.queryengine.plan.relational.function.tvf.ForecastTableFunction;
import org.apache.iotdb.db.queryengine.plan.relational.function.tvf.PatternMatchTableFunction;
import org.apache.iotdb.udf.api.relational.TableFunction;
Expand All @@ -42,7 +43,8 @@ public enum TableBuiltinTableFunction {
VARIATION("variation"),
CAPACITY("capacity"),
FORECAST("forecast"),
PATTERN_MATCH("pattern_match");
PATTERN_MATCH("pattern_match"),
CLASSIFY("classify");

private final String functionName;

Expand Down Expand Up @@ -86,6 +88,8 @@ public static TableFunction getBuiltinTableFunction(String functionName) {
return new CapacityTableFunction();
case "forecast":
return new ForecastTableFunction();
case "classify":
return new ClassifyTableFunction();
default:
throw new UnsupportedOperationException("Unsupported table function: " + functionName);
}
Expand Down
Loading