Skip to content

Commit e51bcdd

Browse files
authored
[AutoScheduler] Support string processing to records (#7144)
* [AutoScheduler] Support string processing to records * doc * remove log
1 parent 5a61089 commit e51bcdd

File tree

4 files changed

+72
-8
lines changed

4 files changed

+72
-8
lines changed

include/tvm/auto_scheduler/measure_record.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
namespace tvm {
3535
namespace auto_scheduler {
3636

37+
const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.4"; // NOLINT(*)
38+
3739
/*! \brief Callback for logging the input and results of measurements to file */
3840
class RecordToFileNode : public MeasureCallbackNode {
3941
public:
@@ -116,9 +118,11 @@ class RecordReader : public ObjectRef {
116118
* \param os A pointer to a output stream.
117119
* \param inputs The MeasureInputs to be written.
118120
* \param results The MeasureResults to be written.
121+
* \param log_version The log version for the given record.
119122
*/
120123
void WriteMeasureRecords(std::ostream* os, const Array<MeasureInput>& inputs,
121-
const Array<MeasureResult>& results);
124+
const Array<MeasureResult>& results,
125+
const std::string log_version = AUTO_SCHEDULER_LOG_VERSION);
122126

123127
/*!
124128
* \brief Read one measure record from a string.

python/tvm/auto_scheduler/measure_record.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,43 @@ def __iter__(self):
9898
yield ret[0], ret[1] # (input, result)
9999

100100

101+
def load_record_from_string(record):
102+
"""
103+
Load the measure record from string.
104+
105+
Parameters
106+
----------
107+
record: str
108+
A record string, including the serialized MeausreInput and MeasureResult.
109+
110+
Returns
111+
-------
112+
ret: Tuple[MeasureInput, MeasureResult]
113+
A tuple of MeasureInput, MeasureResult.
114+
"""
115+
return _ffi_api.ReadMeasureRecord(record)
116+
117+
118+
def dump_record_to_string(inp, res):
119+
"""
120+
Dump the measure record to a string.
121+
122+
Parameters
123+
----------
124+
inp: MeasureInput
125+
The measure input.
126+
127+
res: MeasureResult
128+
The measure result.
129+
130+
Returns
131+
-------
132+
ret: str
133+
The dumped string.
134+
"""
135+
return _ffi_api.WriteMeasureRecords(inp, res)
136+
137+
101138
def load_records(filename):
102139
"""
103140
Load measurement records from a file.

src/auto_scheduler/measure_record.cc

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -279,22 +279,20 @@ namespace auto_scheduler {
279279
TVM_REGISTER_OBJECT_TYPE(RecordToFileNode);
280280
TVM_REGISTER_OBJECT_TYPE(RecordReaderNode);
281281

282-
const std::string AUTO_SCHEDULER_LOG_VERSION = "v0.4"; // NOLINT(*)
283-
284282
RecordToFile::RecordToFile(String filename) {
285283
auto node = make_object<RecordToFileNode>();
286284
node->filename = std::move(filename);
287285
data_ = std::move(node);
288286
}
289287

290288
void WriteMeasureRecords(std::ostream* os, const Array<MeasureInput>& inputs,
291-
const Array<MeasureResult>& results) {
289+
const Array<MeasureResult>& results, const std::string log_version) {
292290
dmlc::JSONWriter writer(os);
293291
for (size_t i = 0; i < inputs.size(); ++i) {
294292
writer.BeginObject(false);
295293
writer.WriteObjectKeyValue("i", *inputs[i].operator->());
296294
writer.WriteObjectKeyValue("r", *results[i].operator->());
297-
writer.WriteObjectKeyValue("v", AUTO_SCHEDULER_LOG_VERSION);
295+
writer.WriteObjectKeyValue("v", log_version);
298296
writer.EndObject();
299297
*os << "\n";
300298
}
@@ -398,6 +396,23 @@ TVM_REGISTER_GLOBAL("auto_scheduler.RecordReaderReadNext").set_body_typed([](Rec
398396
}
399397
});
400398

399+
TVM_REGISTER_GLOBAL("auto_scheduler.ReadMeasureRecord").set_body_typed([](const std::string& str) {
400+
auto inp = make_object<MeasureInputNode>();
401+
auto res = make_object<MeasureResultNode>();
402+
std::string log_version;
403+
ReadMeasureRecord(str, inp.get(), res.get(), &log_version);
404+
return Array<ObjectRef>{ObjectRef(inp), ObjectRef(res)};
405+
});
406+
407+
TVM_REGISTER_GLOBAL("auto_scheduler.WriteMeasureRecords")
408+
.set_body_typed([](MeasureInput inp, MeasureResult res) {
409+
auto inps = Array<MeasureInput>({inp});
410+
auto ress = Array<MeasureResult>({res});
411+
std::ostringstream ss;
412+
WriteMeasureRecords(&ss, inps, ress);
413+
return String(ss.str());
414+
});
415+
401416
TVM_REGISTER_GLOBAL("auto_scheduler.SaveRecords")
402417
.set_body_typed([](String filename, Array<MeasureInput> in, Array<MeasureResult> res) {
403418
std::ofstream ofs(filename, std::ofstream::app);

tests/python/unittest/test_auto_scheduler_measure.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,19 @@ def record_common(dag, s):
3434
inp = auto_scheduler.measure.MeasureInput(task, s)
3535
res = auto_scheduler.measure.MeasureResult([0.1], 0, "", 0.2, 1)
3636

37+
# Test in-memory record processing.
38+
record_str = auto_scheduler.measure_record.dump_record_to_string(inp, res)
39+
r_inp, r_res = auto_scheduler.measure_record.load_record_from_string(record_str)
40+
# Only check the workload_key for simplification.
41+
assert inp.task.workload_key == r_inp.task.workload_key
42+
assert str(res) == str(r_res)
43+
44+
# Test file-based record processing.
3745
with tempfile.NamedTemporaryFile() as fp:
3846
auto_scheduler.save_records(fp.name, [inp], [res])
3947

4048
log_reader = auto_scheduler.RecordReader(fp.name)
41-
inputs, results = log_reader.read_lines()
49+
inputs, _ = log_reader.read_lines()
4250
assert len(inputs) == 1
4351

4452
s1 = dag.infer_bound_from_state(s)
@@ -180,7 +188,7 @@ def test_recover_measure_input():
180188
auto_scheduler.save_records(fp.name, [inp], [res])
181189

182190
log_reader = auto_scheduler.RecordReader(fp.name)
183-
inputs, results = log_reader.read_lines()
191+
inputs, _ = log_reader.read_lines()
184192
assert len(inputs) == 1
185193

186194
raw_inp = inputs[0]
@@ -266,7 +274,7 @@ def test_measure_target_host():
266274
auto_scheduler.save_records(fp.name, [inp], [res])
267275

268276
log_reader = auto_scheduler.RecordReader(fp.name)
269-
inputs, results = log_reader.read_lines()
277+
inputs, _ = log_reader.read_lines()
270278
assert len(inputs) == 1
271279

272280
raw_inp = inputs[0]

0 commit comments

Comments
 (0)