Skip to content

Commit bc098c3

Browse files
ochafik and ochafik authored
minja: sync (qwen3) (#13573)
* minja: sync google/minja@f06140f
  - google/minja#67 (@grf53)
  - google/minja#66 (@taha-yassine)
  - google/minja#63 (@grf53)
  - google/minja#58

---------

Co-authored-by: ochafik <ochafik@google.com>
1 parent c6a2c9e commit bc098c3

File tree

2 files changed: +78 -41 lines changed

common/minja/chat-template.hpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@
1313
#include <chrono>
1414
#include <cstddef>
1515
#include <cstdio>
16+
#include <ctime>
1617
#include <exception>
1718
#include <iomanip>
1819
#include <memory>
1920
#include <sstream>
21+
#include <stdexcept>
2022
#include <string>
2123
#include <vector>
2224

@@ -393,8 +395,8 @@ class chat_template {
393395

394396
for (const auto & message_ : adjusted_messages) {
395397
auto message = message_;
396-
if (!message.contains("role") || !message.contains("content")) {
397-
throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
398+
if (!message.contains("role") || (!message.contains("content") && !message.contains("tool_calls"))) {
399+
throw std::runtime_error("message must have 'role' and one of 'content' or 'tool_calls' fields: " + message.dump());
398400
}
399401
std::string role = message.at("role");
400402

@@ -415,7 +417,6 @@ class chat_template {
415417
}
416418
}
417419
if (polyfill_tool_calls) {
418-
auto content = message.at("content");
419420
auto tool_calls = json::array();
420421
for (const auto & tool_call : message.at("tool_calls")) {
421422
if (tool_call.at("type") != "function") {
@@ -434,8 +435,11 @@ class chat_template {
434435
auto obj = json {
435436
{"tool_calls", tool_calls},
436437
};
437-
if (!content.is_null() && !content.empty()) {
438-
obj["content"] = content;
438+
if (message.contains("content")) {
439+
auto content = message.at("content");
440+
if (!content.is_null() && !content.empty()) {
441+
obj["content"] = content;
442+
}
439443
}
440444
message["content"] = obj.dump(2);
441445
message.erase("tool_calls");

common/minja/minja.hpp

Lines changed: 69 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <algorithm>
1212
#include <cctype>
1313
#include <cstddef>
14+
#include <cstdint>
1415
#include <cmath>
1516
#include <exception>
1617
#include <functional>
@@ -233,7 +234,7 @@ class Value : public std::enable_shared_from_this<Value> {
233234
}
234235
} else if (is_object()) {
235236
if (!index.is_hashable())
236-
throw std::runtime_error("Unashable type: " + index.dump());
237+
throw std::runtime_error("Unhashable type: " + index.dump());
237238
auto it = object_->find(index.primitive_);
238239
if (it == object_->end())
239240
throw std::runtime_error("Key not found: " + index.dump());
@@ -252,7 +253,7 @@ class Value : public std::enable_shared_from_this<Value> {
252253
auto index = key.get<int>();
253254
return array_->at(index < 0 ? array_->size() + index : index);
254255
} else if (object_) {
255-
if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
256+
if (!key.is_hashable()) throw std::runtime_error("Unhashable type: " + dump());
256257
auto it = object_->find(key.primitive_);
257258
if (it == object_->end()) return Value();
258259
return it->second;
@@ -261,7 +262,7 @@ class Value : public std::enable_shared_from_this<Value> {
261262
}
262263
void set(const Value& key, const Value& value) {
263264
if (!object_) throw std::runtime_error("Value is not an object: " + dump());
264-
if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
265+
if (!key.is_hashable()) throw std::runtime_error("Unhashable type: " + dump());
265266
(*object_)[key.primitive_] = value;
266267
}
267268
Value call(const std::shared_ptr<Context> & context, ArgumentsValue & args) const {
@@ -398,7 +399,7 @@ class Value : public std::enable_shared_from_this<Value> {
398399
}
399400
return false;
400401
} else if (object_) {
401-
if (!value.is_hashable()) throw std::runtime_error("Unashable type: " + value.dump());
402+
if (!value.is_hashable()) throw std::runtime_error("Unhashable type: " + value.dump());
402403
return object_->find(value.primitive_) != object_->end();
403404
} else {
404405
throw std::runtime_error("contains can only be called on arrays and objects: " + dump());
@@ -416,7 +417,7 @@ class Value : public std::enable_shared_from_this<Value> {
416417
return const_cast<Value*>(this)->at(index);
417418
}
418419
Value& at(const Value & index) {
419-
if (!index.is_hashable()) throw std::runtime_error("Unashable type: " + dump());
420+
if (!index.is_hashable()) throw std::runtime_error("Unhashable type: " + dump());
420421
if (is_array()) return array_->at(index.get<int>());
421422
if (is_object()) return object_->at(index.primitive_);
422423
throw std::runtime_error("Value is not an array or object: " + dump());
@@ -676,8 +677,8 @@ class Expression {
676677
class VariableExpr : public Expression {
677678
std::string name;
678679
public:
679-
VariableExpr(const Location & location, const std::string& n)
680-
: Expression(location), name(n) {}
680+
VariableExpr(const Location & loc, const std::string& n)
681+
: Expression(loc), name(n) {}
681682
std::string get_name() const { return name; }
682683
Value do_evaluate(const std::shared_ptr<Context> & context) const override {
683684
if (!context->contains(name)) {
@@ -1200,9 +1201,9 @@ class DictExpr : public Expression {
12001201

12011202
class SliceExpr : public Expression {
12021203
public:
1203-
std::shared_ptr<Expression> start, end;
1204-
SliceExpr(const Location & loc, std::shared_ptr<Expression> && s, std::shared_ptr<Expression> && e)
1205-
: Expression(loc), start(std::move(s)), end(std::move(e)) {}
1204+
std::shared_ptr<Expression> start, end, step;
1205+
SliceExpr(const Location & loc, std::shared_ptr<Expression> && s, std::shared_ptr<Expression> && e, std::shared_ptr<Expression> && st = nullptr)
1206+
: Expression(loc), start(std::move(s)), end(std::move(e)), step(std::move(st)) {}
12061207
Value do_evaluate(const std::shared_ptr<Context> &) const override {
12071208
throw std::runtime_error("SliceExpr not implemented");
12081209
}
@@ -1219,18 +1220,35 @@ class SubscriptExpr : public Expression {
12191220
if (!index) throw std::runtime_error("SubscriptExpr.index is null");
12201221
auto target_value = base->evaluate(context);
12211222
if (auto slice = dynamic_cast<SliceExpr*>(index.get())) {
1222-
auto start = slice->start ? slice->start->evaluate(context).get<int64_t>() : 0;
1223-
auto end = slice->end ? slice->end->evaluate(context).get<int64_t>() : (int64_t) target_value.size();
1223+
auto len = target_value.size();
1224+
auto wrap = [len](int64_t i) -> int64_t {
1225+
if (i < 0) {
1226+
return i + len;
1227+
}
1228+
return i;
1229+
};
1230+
int64_t step = slice->step ? slice->step->evaluate(context).get<int64_t>() : 1;
1231+
if (!step) {
1232+
throw std::runtime_error("slice step cannot be zero");
1233+
}
1234+
int64_t start = slice->start ? wrap(slice->start->evaluate(context).get<int64_t>()) : (step < 0 ? len - 1 : 0);
1235+
int64_t end = slice->end ? wrap(slice->end->evaluate(context).get<int64_t>()) : (step < 0 ? -1 : len);
12241236
if (target_value.is_string()) {
12251237
std::string s = target_value.get<std::string>();
1226-
if (start < 0) start = s.size() + start;
1227-
if (end < 0) end = s.size() + end;
1228-
return s.substr(start, end - start);
1238+
1239+
std::string result;
1240+
if (start < end && step == 1) {
1241+
result = s.substr(start, end - start);
1242+
} else {
1243+
for (int64_t i = start; step > 0 ? i < end : i > end; i += step) {
1244+
result += s[i];
1245+
}
1246+
}
1247+
return result;
1248+
12291249
} else if (target_value.is_array()) {
1230-
if (start < 0) start = target_value.size() + start;
1231-
if (end < 0) end = target_value.size() + end;
12321250
auto result = Value::array();
1233-
for (auto i = start; i < end; ++i) {
1251+
for (int64_t i = start; step > 0 ? i < end : i > end; i += step) {
12341252
result.push_back(target_value.at(i));
12351253
}
12361254
return result;
@@ -1305,6 +1323,8 @@ class BinaryOpExpr : public Expression {
13051323
if (name == "iterable") return l.is_iterable();
13061324
if (name == "sequence") return l.is_array();
13071325
if (name == "defined") return !l.is_null();
1326+
if (name == "true") return l.to_bool();
1327+
if (name == "false") return !l.to_bool();
13081328
throw std::runtime_error("Unknown type for 'is' operator: " + name);
13091329
};
13101330
auto value = eval();
@@ -1520,6 +1540,10 @@ class MethodCallExpr : public Expression {
15201540
vargs.expectArgs("endswith method", {1, 1}, {0, 0});
15211541
auto suffix = vargs.args[0].get<std::string>();
15221542
return suffix.length() <= str.length() && std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
1543+
} else if (method->get_name() == "startswith") {
1544+
vargs.expectArgs("startswith method", {1, 1}, {0, 0});
1545+
auto prefix = vargs.args[0].get<std::string>();
1546+
return prefix.length() <= str.length() && std::equal(prefix.begin(), prefix.end(), str.begin());
15231547
} else if (method->get_name() == "title") {
15241548
vargs.expectArgs("title method", {0, 0}, {0, 0});
15251549
auto res = str;
@@ -2082,28 +2106,37 @@ class Parser {
20822106

20832107
while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) {
20842108
if (!consumeToken("[").empty()) {
2085-
std::shared_ptr<Expression> index;
2109+
std::shared_ptr<Expression> index;
2110+
auto slice_loc = get_location();
2111+
std::shared_ptr<Expression> start, end, step;
2112+
bool has_first_colon = false, has_second_colon = false;
2113+
2114+
if (!peekSymbols({ ":" })) {
2115+
start = parseExpression();
2116+
}
2117+
2118+
if (!consumeToken(":").empty()) {
2119+
has_first_colon = true;
2120+
if (!peekSymbols({ ":", "]" })) {
2121+
end = parseExpression();
2122+
}
20862123
if (!consumeToken(":").empty()) {
2087-
auto slice_end = parseExpression();
2088-
index = std::make_shared<SliceExpr>(slice_end->location, nullptr, std::move(slice_end));
2089-
} else {
2090-
auto slice_start = parseExpression();
2091-
if (!consumeToken(":").empty()) {
2092-
consumeSpaces();
2093-
if (peekSymbols({ "]" })) {
2094-
index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), nullptr);
2095-
} else {
2096-
auto slice_end = parseExpression();
2097-
index = std::make_shared<SliceExpr>(slice_start->location, std::move(slice_start), std::move(slice_end));
2098-
}
2099-
} else {
2100-
index = std::move(slice_start);
2124+
has_second_colon = true;
2125+
if (!peekSymbols({ "]" })) {
2126+
step = parseExpression();
21012127
}
21022128
}
2103-
if (!index) throw std::runtime_error("Empty index in subscript");
2104-
if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript");
2129+
}
2130+
2131+
if ((has_first_colon || has_second_colon) && (start || end || step)) {
2132+
index = std::make_shared<SliceExpr>(slice_loc, std::move(start), std::move(end), std::move(step));
2133+
} else {
2134+
index = std::move(start);
2135+
}
2136+
if (!index) throw std::runtime_error("Empty index in subscript");
2137+
if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript");
21052138

2106-
value = std::make_shared<SubscriptExpr>(value->location, std::move(value), std::move(index));
2139+
value = std::make_shared<SubscriptExpr>(value->location, std::move(value), std::move(index));
21072140
} else if (!consumeToken(".").empty()) {
21082141
auto identifier = parseIdentifier();
21092142
if (!identifier) throw std::runtime_error("Expected identifier in subscript");

0 commit comments

Comments
 (0)