Skip to content

Commit

Permalink
use multi-threads in cpu auto-scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
yaoyaoding committed Aug 3, 2023
1 parent d4fcf33 commit 1a16bec
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 4 deletions.
9 changes: 9 additions & 0 deletions python/hidet/drivers/build_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,17 @@ def get_signature(t: TensorNode, device: str) -> TensorSignature:
device=device, dtype=t.type.dtype.name, shape=[int(v) if is_constant(v) else str(v) for v in t.shape]
)

# extract the task name
from hidet.graph.ops.fusion.fused_operator import FusedTask

if isinstance(task, FusedTask):
task_name = 'fused_{}'.format(task.attrs['fused_ops'].replace(' ', '_'))
else:
task_name = task.name

# generate meta data
meta = TaskMetaData(
name=task_name,
symbols=[v.name for v in task.symbols],
inputs=[get_signature(t, input_device) for t in task.inputs],
outputs=[get_signature(t, output_device) for t in task.outputs],
Expand Down
2 changes: 1 addition & 1 deletion python/hidet/ir/builders/stmt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union, Optional, Sequence
from typing import Union, Sequence

from hidet.ir.stmt import Stmt, ForStmt, IfStmt, EvaluateStmt, SeqStmt, LetStmt, ForMappingStmt, ForStmtAttr
from hidet.ir.expr import Expr, Var, var, convert
Expand Down
4 changes: 2 additions & 2 deletions python/hidet/ir/schedulers/cpu/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from hidet.ir.builders import FunctionBuilder
from hidet.ir.compute import TensorNode, GridCompute
from hidet.ir.expr import Var, convert, call
from hidet.ir.expr import Var, call
from hidet.ir.tools import rewrite
from hidet.ir.stmt import Stmt, BufferStoreStmt, EvaluateStmt
from hidet.ir.schedulers.base import AutoScheduler, ComputeExprLower
Expand All @@ -35,7 +35,7 @@ def schedule_grid_compute(self, node: GridCompute, tensor_map: Dict[TensorNode,
fb.extend_params(params)

iter_names = [f'i{i}' for i in range(len(node.shape))]
with fb.for_loop('w', extent=prod(node.shape)) as w:
with fb.for_loop('w', extent=prod(node.shape), attr='p') as w:
with fb.for_mapping(iter_names, row_spatial(*node.shape), worker=w) as task_index:
out_param: Var = param_map[node]
compute_lower = ComputeExprLower(node.value, param_map=param_map)
Expand Down
30 changes: 29 additions & 1 deletion python/hidet/runtime/compiled_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
from hidet.runtime.compiled_task import CompiledTask, TensorSignature, _check_inputs
from hidet.runtime.storage import Storage
from hidet.ffi import runtime_api
from hidet.utils import prod
from hidet.utils.py import prod, median
from hidet.utils.trace_utils import TraceEventEmitter

ModelExecutionHook = Callable[[int, List['Tensor'], List['Tensor']], None]

Expand Down Expand Up @@ -97,6 +98,7 @@ def __init__(
self.constant_outputs: List[Union[None, Tensor]] = []

# runtime state
self.working_dir: str = hidet.utils.cache_file('graphs', self.meta.graph_hash)
self.dispatch_table_path = hidet.utils.cache_file('graphs', self.meta.graph_hash, 'dispatch_table.txt')
self.dispatch_table: Dict[Tuple[int, ...], Array] = {}
self.cuda_workspace: Optional[Storage] = None
Expand Down Expand Up @@ -258,23 +260,49 @@ def _run_slow_path(self, inputs, symbol_dims: Tuple[int, ...]):
index2tensor[exe.inputs_index[i]] = inputs[i]
for i in range(len(self.weights)):
index2tensor[exe.weights_index[i]] = self.weights[i]

best_candidates = [-1 for _ in range(len(self.compiled_tasks))]
trace_emitter = TraceEventEmitter({'graph': self.graph_string})
for inst in exe.instructions:
# prepare inputs and kernel
node_inputs = [index2tensor[i] for i in inst.inputs]
node_kernel: CompiledTask = self.compiled_tasks[inst.task_idx]

# run the kernel
node_outputs = node_kernel.run_async(node_inputs)

# record outputs
for i, output_index in enumerate(inst.outputs):
index2tensor[output_index] = node_outputs[i]

# record best candidate for this kernel
best_candidates[inst.task_idx] = node_kernel.pick_best_candidate(node_inputs, node_outputs)

# record trace events
trace_emitter.append(
name=node_kernel.meta_data.name,
duration_us=int(median(node_kernel.profile(*node_inputs, *node_outputs)) * 1000),
args={
'name': node_kernel.meta_data.name,
'inputs': ['{}{}'.format(x.dtype, x.shape) for x in node_kernel.meta_data.inputs],
'outputs': ['{}{}'.format(x.dtype, x.shape) for x in node_kernel.meta_data.outputs],
},
)

# free tensors that are no longer needed
for idx in inst.free:
del index2tensor[idx]

outputs = [index2tensor[i] for i in exe.outputs_index]

# update the dispatch table
self._update_symbol_table(symbol_dims, best_candidates)

# save the trace
trace_filename = 'trace{}.json'.format('_'.join(str(x) for x in symbol_dims))
with open(os.path.join(self.working_dir, trace_filename), 'w') as f:
trace_emitter.save(f)

return outputs

def run_async(self, inputs):
Expand Down
1 change: 1 addition & 0 deletions python/hidet/runtime/compiled_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class TensorSignature:

@dataclass
class TaskMetaData:
name: str
symbols: List[str]
inputs: List[TensorSignature]
outputs: List[TensorSignature]
Expand Down
8 changes: 8 additions & 0 deletions python/hidet/utils/py.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ def prod(seq: Iterable):
return c


def median(seq: Iterable):
seq = list(seq)
if len(seq) == 0:
return None
else:
return sorted(seq)[len(seq) // 2]


def clip(
x: Union[int, float], low: Optional[Union[int, float]], high: Optional[Union[int, float]]
) -> Union[int, float]:
Expand Down
55 changes: 55 additions & 0 deletions python/hidet/utils/trace_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from typing import Any, Dict, List
from dataclasses import dataclass, asdict
import json


@dataclass
class Event:
name: str
cat: str
ph: str
ts: int
pid: int
tid: int
args: Dict[str, Any]


@dataclass
class TraceEvents:
traceEvents: List[Event]
displayTimeUnit: str = 'ms'
otherData: Dict[str, Any] = None


class TraceEventEmitter:
def __init__(self, other_data: Dict[str, Any] = None):
self.events: List[Event] = []
self.otherData: Dict[str, Any] = other_data if other_data is not None else {}

self.current_ts = 0

def append(self, name: str, duration_us: int, args: Dict[str, Any] = None):
self.events.append(
Event(
name=name, cat='kernel', ph='B', ts=self.current_ts, pid=0, tid=0, args=args if args is not None else {}
)
)
self.current_ts += duration_us
self.events.append(
Event(
name=name, cat='kernel', ph='E', ts=self.current_ts, pid=0, tid=0, args=args if args is not None else {}
)
)

def export(self):
return asdict(TraceEvents(traceEvents=self.events, otherData=self.otherData))

def save(self, f):
json.dump(self.export(), f)


if __name__ == '__main__':
emitter = TraceEventEmitter()
emitter.append('test', 1000)
with open('test.json', 'w') as ff:
emitter.save(ff)

0 comments on commit 1a16bec

Please sign in to comment.