/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <executorch/extension/tensor/tensor.h>
#include <executorch/runtime/executor/method_meta.h>
#include <executorch/runtime/executor/method.h>

namespace executorch {
namespace extension {
namespace llm {

/**
 * @brief Base class for managing input/output operations for LLM inference.
 *
 * IOManagerBase provides an interface for handling the input preparation and
 * output processing for both prefill and decode phases of LLM inference.
 * Derived classes must implement the virtual methods to provide specific IO
 * management functionality.
 */
class ET_EXPERIMENTAL IOManagerBase {
 public:
  /**
   * @brief Virtual destructor to allow proper cleanup in derived classes.
   */
  ET_EXPERIMENTAL virtual ~IOManagerBase() = default;

  /**
   * @brief Initialize the IO manager with method metadata for prefill and
   * decode operations.
   *
   * @param prefill_method The prefill method to initialize with.
   * @param decode_method The decode method to initialize with.
   */
  ET_EXPERIMENTAL ET_NODISCARD virtual runtime::Error init(
      executorch::runtime::Method& prefill_method,
      executorch::runtime::Method& decode_method) = 0;

  /**
   * @brief Reset the IO manager state.
   *
   * @param prefill_method The prefill method to reset with.
   * @param decode_method The decode method to reset with.
   */
  ET_EXPERIMENTAL ET_NODISCARD virtual runtime::Error reset(
      executorch::runtime::Method& prefill_method,
      executorch::runtime::Method& decode_method) = 0;

  /**
   * @brief Prepare inputs for the prefill phase of LLM inference.
   *
   * @param input The input tensor containing token IDs.
   * @param start_pos The tensor containing the starting position of the
   * current input within the context.
   * @param prefill_method The prefill method to prepare inputs for.
   * @return std::vector<executorch::runtime::EValue> Vector of prepared inputs
   * for the prefill method.
   */
  ET_EXPERIMENTAL virtual runtime::Result<
      std::vector<executorch::runtime::EValue>>
  prepare_prefill(
      const executorch::extension::TensorPtr& input,
      const executorch::extension::TensorPtr& start_pos,
      executorch::runtime::Method& prefill_method) = 0;

  /**
   * @brief Prepare inputs for the decode phase of LLM inference.
   *
   * @param input The input tensor containing token IDs.
   * @param start_pos The tensor containing the starting position of the
   * current input within the context.
   * @param decode_method The decode method to prepare inputs for.
   * @return std::vector<executorch::runtime::EValue> Vector of prepared inputs
   * for the decode method.
   */
  ET_EXPERIMENTAL virtual runtime::Result<
      std::vector<executorch::runtime::EValue>>
  prepare_decode(
      const executorch::extension::TensorPtr& input,
      const executorch::extension::TensorPtr& start_pos,
      executorch::runtime::Method& decode_method) = 0;

  /**
   * @brief Process and update internal state with outputs from the prefill
   * phase.
   *
   * @param prefill_method The prefill method to update with outputs.
   * @param model_outputs Vector of outputs from the prefill method execution.
   */
  ET_EXPERIMENTAL ET_NODISCARD virtual runtime::Error update_prefill(
      executorch::runtime::Method& prefill_method,
      const std::vector<executorch::runtime::EValue>& model_outputs) = 0;

  /**
   * @brief Process and update internal state with outputs from the decode
   * phase.
   *
   * @param decode_method The decode method to update with outputs.
   * @param model_outputs Vector of outputs from the decode method execution.
   */
  ET_EXPERIMENTAL ET_NODISCARD virtual runtime::Error update_decode(
      const executorch::runtime::Method& decode_method,
      const std::vector<executorch::runtime::EValue>& model_outputs) = 0;
};
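
/*
 * Illustrative caller-side sketch (not part of this header): how a runner
 * could drive an IOManagerBase implementation for one prefill pass. The
 * concrete subclass name `MyIOManager`, the `tokens` / `start_pos` TensorPtr
 * variables, and the simplified error handling are assumptions made for this
 * example; the Method calls are the standard executorch::runtime::Method API.
 *
 *   MyIOManager io_manager;  // hypothetical IOManagerBase subclass
 *   auto err = io_manager.init(prefill_method, decode_method);
 *   if (err != ::executorch::runtime::Error::Ok) {
 *     return err;
 *   }
 *
 *   // Let the IO manager assemble the prefill inputs (tokens + position).
 *   auto inputs_res =
 *       io_manager.prepare_prefill(tokens, start_pos, prefill_method);
 *   if (!inputs_res.ok()) {
 *     return inputs_res.error();
 *   }
 *   auto& inputs = inputs_res.get();
 *
 *   // Bind the prepared EValues to the method and run it (error checks on
 *   // set_input/execute elided for brevity).
 *   for (size_t i = 0; i < inputs.size(); ++i) {
 *     err = prefill_method.set_input(inputs[i], i);
 *   }
 *   err = prefill_method.execute();
 *
 *   // Hand the outputs back so the manager can update its internal state
 *   // (e.g. cache positions) before the decode loop starts.
 *   std::vector<::executorch::runtime::EValue> outputs;
 *   for (size_t i = 0; i < prefill_method.outputs_size(); ++i) {
 *     outputs.push_back(prefill_method.get_output(i));
 *   }
 *   err = io_manager.update_prefill(prefill_method, outputs);
 */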

} // namespace llm
} // namespace extension
} // namespace executorch
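
/*
 * Minimal sketch of a concrete subclass (illustrative only; the class name
 * `EchoIOManager` and the pass-through behavior are assumptions made for this
 * example, not an implementation shipped with this header). It forwards the
 * token and position tensors unchanged and keeps no internal state:
 *
 *   class EchoIOManager final
 *       : public ::executorch::extension::llm::IOManagerBase {
 *    public:
 *     ::executorch::runtime::Error init(
 *         ::executorch::runtime::Method&,
 *         ::executorch::runtime::Method&) override {
 *       return ::executorch::runtime::Error::Ok;
 *     }
 *
 *     ::executorch::runtime::Error reset(
 *         ::executorch::runtime::Method&,
 *         ::executorch::runtime::Method&) override {
 *       return ::executorch::runtime::Error::Ok;
 *     }
 *
 *     ::executorch::runtime::Result<
 *         std::vector<::executorch::runtime::EValue>>
 *     prepare_prefill(
 *         const ::executorch::extension::TensorPtr& input,
 *         const ::executorch::extension::TensorPtr& start_pos,
 *         ::executorch::runtime::Method&) override {
 *       // Pass the token and position tensors through as the method inputs.
 *       return std::vector<::executorch::runtime::EValue>{
 *           ::executorch::runtime::EValue(*input),
 *           ::executorch::runtime::EValue(*start_pos)};
 *     }
 *
 *     ::executorch::runtime::Result<
 *         std::vector<::executorch::runtime::EValue>>
 *     prepare_decode(
 *         const ::executorch::extension::TensorPtr& input,
 *         const ::executorch::extension::TensorPtr& start_pos,
 *         ::executorch::runtime::Method&) override {
 *       return std::vector<::executorch::runtime::EValue>{
 *           ::executorch::runtime::EValue(*input),
 *           ::executorch::runtime::EValue(*start_pos)};
 *     }
 *
 *     ::executorch::runtime::Error update_prefill(
 *         ::executorch::runtime::Method&,
 *         const std::vector<::executorch::runtime::EValue>&) override {
 *       return ::executorch::runtime::Error::Ok;  // stateless: nothing to do
 *     }
 *
 *     ::executorch::runtime::Error update_decode(
 *         const ::executorch::runtime::Method&,
 *         const std::vector<::executorch::runtime::EValue>&) override {
 *       return ::executorch::runtime::Error::Ok;  // stateless: nothing to do
 *     }
 *   };
 */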