Skip to content

Commit 13a543a

Browse files
JacobSzwejbka authored and facebook-github-bot committed
IOManager Interface (#10418)
Summary: Hopefully this is sufficient for the contract. Going to do 2 follow up tests. Add a basic cpu implementation add a static attention implementation. Differential Revision: D73450877
1 parent 51befee commit 13a543a

File tree

3 files changed

+141
-0
lines changed

3 files changed

+141
-0
lines changed
+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Any targets that should be shared between fbcode and xplat must be defined in
2+
# targets.bzl. This file can contain fbcode-only targets.
3+
4+
load(":targets.bzl", "define_common_targets")
5+
6+
oncall("executorch")
7+
8+
define_common_targets()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/extension/tensor/tensor.h>
12+
#include <executorch/runtime/executor/method_meta.h>
13+
#include <executorch/runtime/executor/method.h>
14+
15+
namespace executorch {
namespace extension {
namespace llm {

/**
 * @brief Base class for managing input/output operations for LLM inference.
 *
 * IOManagerBase provides an interface for handling the input preparation and
 * output processing for both prefill and decode phases of LLM inference.
 * Derived classes must implement the virtual methods to provide specific IO
 * management functionality.
 *
 * All error-returning entry points are marked ET_NODISCARD, so callers are
 * expected to inspect the returned runtime::Error rather than ignore it.
 */
class ET_EXPERIMENTAL IOManagerBase {
 public:
  /**
   * @brief Virtual destructor to allow proper cleanup in derived classes.
   *
   * Required because instances are expected to be destroyed through a
   * pointer to this base interface.
   */
  virtual ~IOManagerBase() = default;

  /**
   * @brief Initialize the IO manager with method metadata for prefill and
   * decode operations.
   *
   * @param prefill_method The prefill method to initialize with.
   * @param decode_method The decode method to initialize with.
   * @return runtime::Error indicating whether initialization succeeded.
   */
  ET_NODISCARD virtual runtime::Error init(
      executorch::runtime::Method& prefill_method,
      executorch::runtime::Method& decode_method) = 0;

  /**
   * @brief Reset the IO manager state.
   *
   * @param prefill_method The prefill method to reset with.
   * @param decode_method The decode method to reset with.
   * @return runtime::Error indicating whether the reset succeeded.
   */
  ET_NODISCARD virtual runtime::Error reset(
      executorch::runtime::Method& prefill_method,
      executorch::runtime::Method& decode_method) = 0;

  /**
   * @brief Prepare inputs for the prefill phase of LLM inference.
   *
   * @param input The input tensor containing token IDs.
   * @param start_pos The tensor containing the starting position of the
   * current input within the context.
   * @param prefill_method The prefill method to prepare inputs for.
   * @return std::vector<executorch::runtime::EValue> Vector of prepared inputs
   * for the prefill method, or an Error on failure.
   */
  virtual runtime::Result<std::vector<executorch::runtime::EValue>>
  prepare_prefill(
      const executorch::extension::TensorPtr& input,
      const executorch::extension::TensorPtr& start_pos,
      executorch::runtime::Method& prefill_method) = 0;

  /**
   * @brief Prepare inputs for the decode phase of LLM inference.
   *
   * @param input The input tensor containing token IDs.
   * @param start_pos The tensor containing the starting position of the
   * current input within the context.
   * @param decode_method The decode method to prepare inputs for.
   * @return std::vector<executorch::runtime::EValue> Vector of prepared inputs
   * for the decode method, or an Error on failure.
   */
  virtual runtime::Result<std::vector<executorch::runtime::EValue>>
  prepare_decode(
      const executorch::extension::TensorPtr& input,
      const executorch::extension::TensorPtr& start_pos,
      executorch::runtime::Method& decode_method) = 0;

  /**
   * @brief Process and update internal state with outputs from the prefill
   * phase.
   *
   * @param prefill_method The prefill method to update with outputs.
   * @param model_outputs Vector of outputs from the prefill method execution.
   * @return runtime::Error indicating whether the update succeeded.
   */
  ET_NODISCARD virtual runtime::Error update_prefill(
      executorch::runtime::Method& prefill_method,
      const std::vector<executorch::runtime::EValue>& model_outputs) = 0;

  /**
   * @brief Process and update internal state with outputs from the decode
   * phase.
   *
   * @param decode_method The decode method to update with outputs.
   * @param model_outputs Vector of outputs from the decode method execution.
   * @return runtime::Error indicating whether the update succeeded.
   *
   * NOTE(review): decode_method is taken as const Method& here, while
   * update_prefill (and every other method in this interface) takes a
   * non-const Method&. Confirm this asymmetry is intentional before
   * implementations depend on it.
   */
  ET_NODISCARD virtual runtime::Error update_decode(
      const executorch::runtime::Method& decode_method,
      const std::vector<executorch::runtime::EValue>& model_outputs) = 0;
};

} // namespace llm
} // namespace extension
} // namespace executorch
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
def define_common_targets():
    """Registers the IOManager interface library.

    Emits one header-only cxx_library per flavor: "io_manager" for the
    portable runtime and "io_manager_aten" for the ATen-backed runtime.
    """
    for use_aten in [True, False]:
        flavor = "_aten" if use_aten else ""

        # Interface for IOManager. No concrete impl from this dep.
        # NOTE(review): io_manager.h #includes headers provided by these
        # deps; consider exported_deps so that dependents of this
        # header-only target inherit them — confirm against Buck
        # cxx_library semantics.
        runtime.cxx_library(
            name = "io_manager" + flavor,
            exported_headers = [
                "io_manager.h",
            ],
            deps = [
                "//executorch/extension/module:module" + flavor,
                "//executorch/extension/tensor:tensor" + flavor,
                "//executorch/runtime/executor:program_no_prim_ops" + flavor,
            ],
            visibility = [
                "@EXECUTORCH_CLIENTS",
            ],
        )

0 commit comments

Comments
 (0)