diff --git a/examples/llm_pte_finetuning/TARGETS b/examples/llm_pte_finetuning/TARGETS
new file mode 100644
index 00000000000..fee67914909
--- /dev/null
+++ b/examples/llm_pte_finetuning/TARGETS
@@ -0,0 +1,70 @@
+load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary")
+load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
+
+oncall("papaya_oncall")
+
+python_library(
+    name = "model_loading_lib",
+    srcs = [
+        "model_loading_lib.py",
+    ],
+    deps = [
+        "fbcode//caffe2:torch",
+        "fbcode//executorch/examples/llm_pte_finetuning:training_lib",
+        "fbcode//executorch/exir:lib",
+        "fbcode//executorch/extension/pybindings:aten_lib",  # @manual For PTE loader
+        "fbcode//pytorch/torchtune:lib",
+        "fbsource//third-party/pypi/blobfile:blobfile",  # @manual For tokenizer
+        "fbsource//third-party/pypi/omegaconf:omegaconf",
+        "fbsource//third-party/pypi/tiktoken:tiktoken",  # @manual For tokenizer
+    ],
+)
+
+python_library(
+    name = "training_lib",
+    srcs = [
+        "training_lib.py",
+    ],
+    deps = [
+        "fbcode//caffe2:torch",
+        "fbcode//executorch/extension/pybindings:aten_lib",  # @manual For PTE loader
+        "fbcode//pytorch/torchtune:lib",
+        "fbsource//third-party/pypi/blobfile:blobfile",  # @manual For tokenizer
+        "fbsource//third-party/pypi/tiktoken:tiktoken",  # @manual For tokenizer
+        "fbsource//third-party/pypi/tqdm:tqdm",
+    ],
+)
+
+python_binary(
+    name = "runner",
+    srcs = [
+        "runner.py",
+    ],
+    main_function = "executorch.examples.llm_pte_finetuning.runner.main",
+    deps = [
+        "fbcode//caffe2:torch",
+        "fbcode//executorch/examples/llm_pte_finetuning:training_lib",
+        "fbcode//pytorch/torchtune:lib",
+        "fbsource//third-party/pypi/blobfile:blobfile",  # @manual For tokenizer
+        "fbsource//third-party/pypi/omegaconf:omegaconf",
+        "fbsource//third-party/pypi/tiktoken:tiktoken",  # @manual For tokenizer
+        "fbsource//third-party/pypi/tqdm:tqdm",
+    ],
+)
+
+python_binary(
+    name = "model_exporter",
+    srcs = [
+        "model_exporter.py",
+    ],
+    main_function = "executorch.examples.llm_pte_finetuning.model_exporter.main",
+    deps = [
+        "fbcode//caffe2:torch",
+        "fbcode//executorch/examples/llm_pte_finetuning:model_loading_lib",  # @manual for model loading
+        "fbcode//executorch/examples/llm_pte_finetuning:training_lib",  # @manual for model exporting
+        "fbcode//pytorch/torchtune:lib",
+        "fbsource//third-party/pypi/blobfile:blobfile",  # @manual For tokenizer
+        "fbsource//third-party/pypi/omegaconf:omegaconf",
+        "fbsource//third-party/pypi/tiktoken:tiktoken",  # @manual For tokenizer
+    ],
+)
diff --git a/examples/llm_pte_finetuning/model_exporter.py b/examples/llm_pte_finetuning/model_exporter.py
new file mode 100644
index 00000000000..e7f074c8769
--- /dev/null
+++ b/examples/llm_pte_finetuning/model_exporter.py
@@ -0,0 +1,87 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import argparse
+
+import torch
+from executorch.examples.llm_pte_finetuning.model_loading_lib import (
+    export_model_lora_training,
+    load_checkpoint,
+    setup_model,
+)
+
+from executorch.examples.llm_pte_finetuning.training_lib import (
+    get_dataloader,
+    TrainingModule,
+)
+
+from omegaconf import OmegaConf
+from torch.nn import functional as F
+from torchtune import config
+
+from torchtune.training import MODEL_KEY
+
+parser = argparse.ArgumentParser(
+    prog="ModelExporter",
+    description="Export a LoRA model to ExecuTorch.",
+    epilog="Model exported to be used for fine-tuning.",
+)
+
+parser.add_argument("--cfg", type=str, help="Path to the config file.")
+parser.add_argument("--output_file", type=str, help="Path to the output ET model.")
+
+
+def main() -> None:
+    args = parser.parse_args()
+    config_file = args.cfg
+    output_file = args.output_file
+    cfg = OmegaConf.load(config_file)
+    tokenizer = config.instantiate(
+        cfg.tokenizer,
+    )
+
+    loss_fn = config.instantiate(cfg.loss)
+
+    ds = config.instantiate(cfg.dataset, tokenizer)
+    train_set, val_set = torch.utils.data.random_split(ds, [0.8, 0.2])
+    train_dataloader = get_dataloader(cfg, train_set, tokenizer, loss_fn)
+
+    max_seq_len = cfg.tokenizer.max_seq_len
+
+    # Example inputs, needed for ET export.
+    batch = next(iter(train_dataloader))
+    tokens, labels = batch["tokens"], batch["labels"]
+    token_size = tokens.shape[1]
+    labels_size = labels.shape[1]
+
+    if token_size > max_seq_len:
+        tokens = tokens[:, :max_seq_len]
+    else:
+        tokens = F.pad(tokens, (0, max_seq_len - token_size), value=0)
+
+    if labels_size > max_seq_len:
+        labels = labels[:, :max_seq_len]
+    else:
+        labels = F.pad(labels, (0, max_seq_len - labels_size), value=0)
+
+    # Load pre-trained checkpoint.
+    checkpoint_dict = load_checkpoint(cfg=cfg)
+    model = setup_model(
+        # pyre-ignore
+        cfg=cfg,
+        base_model_state_dict=checkpoint_dict[MODEL_KEY],
+    )
+
+    training_module = TrainingModule(model, loss_fn)
+
+    # Export the model to ExecuTorch for training.
+    export_model_lora_training(training_module, (tokens, labels), output_file)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/llm_pte_finetuning/model_loading_lib.py b/examples/llm_pte_finetuning/model_loading_lib.py
new file mode 100644
index 00000000000..2edd2d27f5f
--- /dev/null
+++ b/examples/llm_pte_finetuning/model_loading_lib.py
@@ -0,0 +1,86 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from typing import Any, Dict, Tuple
+
+import torch
+from executorch.examples.llm_pte_finetuning.training_lib import TrainingModule
+from executorch.exir import to_edge
+
+from omegaconf import DictConfig
+from torch.export import export, ExportedProgram
+from torch.export.experimental import _export_forward_backward
+from torch.nn.attention import sdpa_kernel, SDPBackend
+from torchtune import config
+from torchtune.modules.peft import get_adapter_params, set_trainable_params
+from torchtune.training.precision import get_dtype, set_default_dtype
+from torchtune.utils._device import get_device
+
+
+def load_checkpoint(cfg: Any) -> Dict[str, Any]:  # pyre-ignore[2]
+    """
+    Extract the checkpoint state from file and validate. This includes the
+    base model weights. If resume_from_checkpoint is True, this also includes
+    the adapter weights and recipe state.
+    """
+    checkpointer = config.instantiate(
+        cfg.checkpointer,
+        resume_from_checkpoint=cfg.resume_from_checkpoint,
+    )
+    checkpoint_dict = checkpointer.load_checkpoint()
+    return checkpoint_dict
+
+
+def setup_model(
+    cfg: DictConfig,
+    base_model_state_dict: Dict[str, Any],
+) -> torch.nn.Module:
+    device = get_device(device=cfg.device)
+    dtype = get_dtype(cfg.dtype, device=device)
+    with set_default_dtype(dtype), device:
+        model = config.instantiate(cfg.model)
+
+    adapter_params = get_adapter_params(model)
+    set_trainable_params(model, adapter_params)
+    model.load_state_dict(base_model_state_dict, strict=False)
+    return model
+
+
+def export_model_lora_training(
+    model: TrainingModule, example_args: Tuple[Any, ...], output_file: str  # pyre-ignore[2]
+) -> None:
+    """
+    Export a model with LoRA adapters to ExecuTorch for training only.
+    """
+
+    # 0. Mark the LoRA layers as trainable (requires_grad = True) in order
+    # to just export the backwards pass for these layers later in the
+    # export process.
+    set_trainable_params(model, get_adapter_params(model))
+
+    print("Exporting model with LoRA for training")
+    # 1. torch.export: Defines the program with the ATen operator set.
+
+    with sdpa_kernel([SDPBackend.MATH]):
+        exported_graph: ExportedProgram = export(model, example_args, strict=False)
+        print("Creating a joint forward-backwards graph for training")
+        joint_graph = _export_forward_backward(exported_graph)
+
+        # 2. to_edge: Make optimizations for Edge devices.
+        print("Lowering to edge dialect")
+        edge_program = to_edge(joint_graph)
+
+        print(edge_program._edge_programs["forward"].graph_module)
+
+        # 3. to_executorch: Convert the graph to an ExecuTorch program.
+ print("Exporting to executorch") + executorch_program = edge_program.to_executorch() + print(executorch_program.exported_program().graph_signature) + print(f"Saving to {output_file}") + with open(output_file, "wb") as file: + file.write(executorch_program.buffer) diff --git a/examples/llm_pte_finetuning/phi3_alpaca_code_config.yaml b/examples/llm_pte_finetuning/phi3_alpaca_code_config.yaml new file mode 100644 index 00000000000..88e5bfac700 --- /dev/null +++ b/examples/llm_pte_finetuning/phi3_alpaca_code_config.yaml @@ -0,0 +1,49 @@ +tokenizer: + _component_: torchtune.models.phi3.phi3_mini_tokenizer + path: /tmp/Phi-3-mini-4k-instruct/tokenizer.model + max_seq_len: 1024 + +dataset: + _component_: torchtune.datasets.instruct_dataset + template: papaya.toolkit.experimental.llm_pte_finetuning.utils.DatabricksDolly + source: iamtarun/python_code_instructions_18k_alpaca + split: train + column_map: + instruction: instruction + prompt: prompt + input: input + output: output +seed: null +shuffle: True +batch_size: 1 + +loss: + _component_: torch.nn.CrossEntropyLoss + +model: + _component_: torchtune.models.phi3.lora_phi3_mini + lora_attn_modules: ['q_proj', 'v_proj'] + apply_lora_to_mlp: False + apply_lora_to_output: False + lora_rank: 8 + lora_alpha: 16 + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Phi-3-mini-4k-instruct + checkpoint_files: [ + model-00001-of-00002.safetensors, + model-00002-of-00002.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/Phi-3-mini-4k-instruct/ + model_type: PHI3_MINI + +resume_from_checkpoint: False +save_adapter_weights_only: False + +device: cpu +dtype: fp32 + +enable_activation_checkpointing: True +compile: False diff --git a/examples/llm_pte_finetuning/phi3_config.yaml b/examples/llm_pte_finetuning/phi3_config.yaml new file mode 100644 index 00000000000..7417ece79bd --- /dev/null +++ b/examples/llm_pte_finetuning/phi3_config.yaml @@ -0,0 +1,40 @@ +tokenizer: + _component_: torchtune.models.phi3.phi3_mini_tokenizer + path: /tmp/Phi-3-mini-4k-instruct/tokenizer.model + max_seq_len: 512 + +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True +batch_size: 1 + +loss: + _component_: torch.nn.CrossEntropyLoss + +model: + _component_: torchtune.models.phi3.lora_phi3_mini + lora_attn_modules: ['q_proj', 'v_proj'] + apply_lora_to_mlp: False + apply_lora_to_output: False + lora_rank: 8 + lora_alpha: 16 + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Phi-3-mini-4k-instruct + checkpoint_files: [ + model-00001-of-00002.safetensors, + model-00002-of-00002.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/Phi-3-mini-4k-instruct/ + model_type: PHI3_MINI +resume_from_checkpoint: False +save_adapter_weights_only: False + +device: cpu +dtype: fp32 + +enable_activation_checkpointing: True +compile: False diff --git a/examples/llm_pte_finetuning/qwen_05b_config.yaml b/examples/llm_pte_finetuning/qwen_05b_config.yaml new file mode 100644 index 00000000000..b93517b8fda --- /dev/null +++ b/examples/llm_pte_finetuning/qwen_05b_config.yaml @@ -0,0 +1,39 @@ +tokenizer: + _component_: torchtune.models.qwen2.qwen2_tokenizer + path: /tmp/Qwen2-0.5B-Instruct/vocab.json + merges_file: /tmp/Qwen2-0.5B-Instruct/merges.txt + max_seq_len: 512 + +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True +batch_size: 1 + +loss: + _component_: torch.nn.CrossEntropyLoss + +model: + _component_: 
+  lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
+  apply_lora_to_mlp: False
+  lora_rank: 32
+  lora_alpha: 64
+
+checkpointer:
+  _component_: torchtune.training.FullModelHFCheckpointer
+  checkpoint_dir: /tmp/Qwen2-0.5B-Instruct
+  checkpoint_files: [
+    model.safetensors
+  ]
+  recipe_checkpoint: null
+  output_dir: /tmp/Qwen2-0.5B-Instruct
+  model_type: QWEN2
+resume_from_checkpoint: False
+save_adapter_weights_only: False
+
+device: cpu
+dtype: fp32
+
+enable_activation_checkpointing: True
+compile: False
diff --git a/examples/llm_pte_finetuning/runner.py b/examples/llm_pte_finetuning/runner.py
new file mode 100644
index 00000000000..2e01fdafe8d
--- /dev/null
+++ b/examples/llm_pte_finetuning/runner.py
@@ -0,0 +1,118 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import argparse
+
+import torch
+from executorch.examples.llm_pte_finetuning.training_lib import (
+    eval_model,
+    get_dataloader,
+    update_function,
+)
+
+from executorch.extension.pybindings.aten_lib import (  # @manual
+    _load_for_executorch_from_buffer,
+)
+from omegaconf import OmegaConf
+from torch.nn import functional as F
+from torchtune import config
+from tqdm import tqdm
+
+parser = argparse.ArgumentParser(
+    prog="Runner",
+    description="Fine-tunes a LoRA model using ExecuTorch.",
+    epilog="Model exported to be used for fine-tuning.",
+)
+parser.add_argument("--cfg", type=str, help="Path to the config file.")
+parser.add_argument("--model_file", type=str, help="Path to the ET model file.")
+
+
+def main() -> None:
+    args = parser.parse_args()
+    config_file = args.cfg
+    file = args.model_file
+    cfg = OmegaConf.load(config_file)
+    tokenizer = config.instantiate(
+        cfg.tokenizer,
+    )
+
+    loss_fn = config.instantiate(cfg.loss)
+
+    ds = config.instantiate(cfg.dataset, tokenizer)
+    train_set, val_set = torch.utils.data.random_split(ds, [0.8, 0.2])
+    train_dataloader = get_dataloader(cfg, train_set, tokenizer, loss_fn)
+    val_dataloader = get_dataloader(cfg, val_set, tokenizer, loss_fn)
+
+    max_seq_len = cfg.tokenizer.max_seq_len
+    # Num of steps to run training. Assume 1 epoch
+    num_steps = 100
+    with open(file, "rb") as f:
+        model_bytes = f.read()
+        et_mod = _load_for_executorch_from_buffer(model_bytes)
+
+    # Evaluate the model before training.
+    print("Evaluating the model before training")
+    eval_loss = eval_model(
+        model=et_mod,
+        dataloader=val_dataloader,
+        loss_fn=loss_fn,
+        max_seq_len=max_seq_len,
+        num_eval_steps=10,
+    )
+    print("Eval loss: ", eval_loss)
+
+    # Based on executorch/extension/training/module/training_module.cpp
+    # grads run from [grad_start, param_start]
+    # params run from [param_start, outputs_end]
+    grad_start = et_mod.run_method("__et_training_gradients_index_forward", [])[0]
+    param_start = et_mod.run_method("__et_training_parameters_index_forward", [])[0]
+    learning_rate = 5e-3
+    losses = []
+    for i, batch in tqdm(enumerate(train_dataloader), total=num_steps):
+        # Run for a limited number of steps.
+        if i >= num_steps:
+            break
+        tokens, labels = batch["tokens"], batch["labels"]
+        token_size = tokens.shape[1]
+        labels_size = labels.shape[1]
+
+        # Fixed length for now. We need to resize as the input shapes
+        # should match the example inputs passed to the export function.
+        if token_size > max_seq_len:
+            tokens = tokens[:, :max_seq_len]
+        else:
+            tokens = F.pad(tokens, (0, max_seq_len - token_size), value=0)
+
+        if labels_size > max_seq_len:
+            labels = labels[:, :max_seq_len]
+        else:
+            labels = F.pad(labels, (0, max_seq_len - labels_size), value=0)
+
+        # Do not clone outputs; the in-place update below must reach module weights.
+        out = et_mod.forward((tokens, labels), clone_outputs=False)
+
+        loss = out[0]
+        losses.append(loss.item())
+        with torch.no_grad():
+            for grad, param in zip(out[grad_start:param_start], out[param_start:]):
+                update_function(param, grad, learning_rate)
+
+    print("Losses: ", losses)
+    # Evaluate the model after training.
+    eval_loss = eval_model(
+        model=et_mod,
+        dataloader=val_dataloader,
+        loss_fn=loss_fn,
+        max_seq_len=max_seq_len,
+        num_eval_steps=10,
+    )
+    print("Eval loss: ", eval_loss)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/llm_pte_finetuning/training_lib.py b/examples/llm_pte_finetuning/training_lib.py
new file mode 100644
index 00000000000..e43154937e8
--- /dev/null
+++ b/examples/llm_pte_finetuning/training_lib.py
@@ -0,0 +1,163 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from functools import partial
+from typing import Any, Dict, Mapping, Optional
+
+import torch
+from executorch.extension.pybindings.aten_lib import ExecuTorchModule  # @manual
+
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, Dataset, DistributedSampler
+from torchtune.data import InstructTemplate
+from torchtune.data._collate import padded_collate_sft
+from tqdm import tqdm
+
+
+class TrainingModule(torch.nn.Module):
+    """
+    The model being trained should return the loss from forward(). This
+    class wraps the actual model and computes the loss for an LLM
+    fine-tuning task. The loss is computed as the cross entropy between
+    the tokens and a shifted version of the labels so we learn to predict
+    the next token.
+    """
+
+    def __init__(
+        self, model: torch.nn.Module, loss: torch.nn.modules.loss._Loss
+    ) -> None:
+        super().__init__()
+        self.model = model
+        self.loss = loss
+
+    def forward(self, input: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+        # Output is of the shape (batch_size, seq_len, vocab_size).
+        logits = self.model(input)
+        logits = logits[..., :-1, :].contiguous()
+        labels = labels[..., 1:].contiguous()
+        logits = logits.transpose(1, 2)
+        return self.loss(logits, labels)
+
+
+class DatabricksDolly(InstructTemplate):
+    """
+    Used for the Dolly dataset from Databricks.
+
+    https://huggingface.co/datasets/databricks/databricks-dolly-15k
+    """
+
+    template = "Instruction:\n{instruction}\n\nContext:\n{input}\n\nResponse: "
+
+    @classmethod
+    def format(
+        cls,
+        sample: Mapping[str, Any],
+        column_map: Optional[Dict[str, str]],
+    ) -> str:
+        assert column_map is not None
+        instruction = sample[column_map["instruction"]]
+        input = sample[column_map["input"]]
+        return cls.template.format(instruction=instruction, input=input)
+
+
+class PythonCodeInstructions(InstructTemplate):
+    """
+    https://huggingface.co/datasets/iamtarun/python_code_instructions_18k_alpaca
+    """
+
+    template = (
+        "{prompt}\n\n"
+        "Instruction:\n{instruction}"
+        "\n\nContext:\n{input}\n\nResponse: "
+    )
+
+    @classmethod
+    def format(
+        cls,
+        sample: Mapping[str, Any],
+        column_map: Optional[Dict[str, str]],
+    ) -> str:
+        assert column_map is not None
+        instruction = sample[column_map["instruction"]]
+        input = sample[column_map["input"]]
+        prompt = sample[column_map["prompt"]]
+        return cls.template.format(instruction=instruction, input=input, prompt=prompt)
+
+
+def update_function(
+    param: torch.Tensor,
+    grad: torch.Tensor,
+    learning_rate: float,
+    weight_decay: float = 1.0,
+) -> None:
+    """SGD update function."""
+    grad = grad + weight_decay * param
+    param.sub_(learning_rate * grad)
+
+
+def eval_model(
+    model: ExecuTorchModule,
+    dataloader: DataLoader,
+    loss_fn: torch.nn.modules.loss._Loss,
+    max_seq_len: int,
+    num_eval_steps: int,
+) -> float:
+    total_loss = 0
+    for i, batch in tqdm(enumerate(dataloader), total=num_eval_steps):
+        if i >= num_eval_steps:
+            break
+        tokens, labels = batch["tokens"], batch["labels"]
+        token_size = tokens.shape[1]
+        labels_size = labels.shape[1]
+
+        # Fixed length for now. We need to resize as the input shapes
+        # should match the example inputs passed to the export function.
+        if token_size > max_seq_len:
+            tokens = tokens[:, :max_seq_len]
+        else:
+            tokens = F.pad(tokens, (0, max_seq_len - token_size), value=0)
+
+        if labels_size > max_seq_len:
+            labels = labels[:, :max_seq_len]
+        else:
+            labels = F.pad(labels, (0, max_seq_len - labels_size), value=0)
+
+        out = model.forward((tokens, labels), True)
+        loss = out[0]
+        total_loss += loss
+    return total_loss / num_eval_steps
+
+
+def get_dataloader(
+    cfg: Any, ds: Dataset[Any], tokenizer: Any, loss_fn: torch.nn.modules.loss._Loss  # pyre-ignore[2]
+) -> DataLoader:
+    """Given a dataset, tokenizer, and loss function, return a dataloader."""
+    packed = cfg.dataset.get("packed", False)
+
+    sampler = DistributedSampler(
+        ds,
+        num_replicas=1,
+        rank=0,
+        shuffle=cfg.shuffle,
+        seed=0,
+    )
+    dataloader = DataLoader(
+        dataset=ds,
+        sampler=sampler,
+        batch_size=cfg.batch_size,
+        collate_fn=(
+            partial(
+                padded_collate_sft,
+                padding_idx=tokenizer.pad_id,
+                ignore_idx=loss_fn.ignore_index,
+            )
+            if not packed
+            else None
+        ),
+    )
+    return dataloader
diff --git a/extension/pybindings/pybindings.cpp b/extension/pybindings/pybindings.cpp
index 57bc44d1394..d674f2fe58c 100644
--- a/extension/pybindings/pybindings.cpp
+++ b/extension/pybindings/pybindings.cpp
@@ -509,7 +509,8 @@ struct PyModule final {
 
   py::list run_method(
       const std::string& method_name,
-      const py::sequence& inputs) {
+      const py::sequence& inputs,
+      bool clone_outputs = true) {
     const auto inputs_size = py::len(inputs);
     std::vector<EValue> cpp_inputs;
     cpp_inputs.reserve(inputs_size);
@@ -603,17 +604,19 @@ struct PyModule final {
     module_->run_method(method_name, cpp_inputs, output_storage_spans);
 
     // Retrieve outputs
-    return get_outputs_as_py_list(outputs);
+    return get_outputs_as_py_list(outputs, clone_outputs);
   }
 
-  py::list forward(const py::sequence& inputs) {
-    return run_method("forward", inputs);
+  py::list forward(const py::sequence& inputs, bool clone_outputs = true) {
+    return run_method("forward", inputs, clone_outputs);
   }
 
-  py::list forward_single_input(const torch::Tensor& inputTensor) {
+  py::list forward_single_input(
+      const torch::Tensor& inputTensor,
+      bool clone_outputs = true) {
     py::list py_list;
     py_list.append(py::cast(inputTensor));
-    return run_method("forward", py_list);
+    return run_method("forward", py_list, clone_outputs);
   }
 
   bool has_etdump() {
@@ -686,7 +689,9 @@ struct PyModule final {
     return outputs;
   }
 
-  py::list plan_execute(const std::string method_name) {
+  py::list plan_execute(
+      const std::string method_name,
+      bool clone_outputs = true) {
     auto& method = module_->get_method(method_name);
     // Need to pre-allocate space for outputs just like in run_method.
     const auto num_outputs = method.outputs_size();
@@ -703,10 +708,12 @@ struct PyModule final {
         "executing execution plan for method 'forward' failed with error: 0x%" PRIx32,
         static_cast<uint32_t>(status));
     const auto outputs = module_->get_outputs(method_name);
-    return get_outputs_as_py_list(outputs);
+    return get_outputs_as_py_list(outputs, clone_outputs);
   }
 
-  py::list get_outputs_as_py_list(const std::vector<EValue>& outputs) {
+  py::list get_outputs_as_py_list(
+      const std::vector<EValue>& outputs,
+      bool clone_outputs = true) {
     const auto outputs_size = outputs.size();
     py::list list(outputs_size);
     for (size_t i = 0; i < outputs_size; ++i) {
@@ -725,9 +732,17 @@ struct PyModule final {
 #ifdef USE_ATEN_LIB
         // Clone so the outputs in python do not share a lifetime with the
         // module object
-        list[i] = py::cast(v.toTensor().clone());
+        if (clone_outputs) {
+          list[i] = py::cast(v.toTensor().clone());
+        } else {
+          list[i] = py::cast(v.toTensor());
+        }
 #else
-        list[i] = py::cast(alias_attensor_to_etensor(v.toTensor()).clone());
+        if (clone_outputs) {
+          list[i] = py::cast(alias_attensor_to_etensor(v.toTensor()).clone());
+        } else {
+          list[i] = py::cast(alias_attensor_to_etensor(v.toTensor()));
+        }
 #endif
       } else {
         ET_ASSERT_UNREACHABLE_MSG("Invalid model output type");
       }
@@ -845,14 +860,25 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
           py::arg("rtol") = 1e-5,
          py::arg("atol") = 1e-8,
          call_guard)
-      .def("plan_execute", &PyModule::plan_execute, call_guard)
+      .def(
+          "plan_execute",
+          &PyModule::plan_execute,
+          py::arg("method_name"),
+          py::arg("clone_outputs") = true,
+          call_guard)
       .def(
           "run_method",
          &PyModule::run_method,
          py::arg("method_name"),
          py::arg("inputs") = py::list(),
+          py::arg("clone_outputs") = true,
+          call_guard)
+      .def(
+          "forward",
+          &PyModule::forward,
+          py::arg("inputs") = py::list(),
+          py::arg("clone_outputs") = true,
          call_guard)
-      .def("forward", &PyModule::forward, call_guard)
       .def("has_etdump", &PyModule::has_etdump, call_guard)
       .def(
          "write_etdump_result_to_file",
          py::arg("path"),
          py::arg("debug_buffer_path") = py::none(),
          call_guard)
-      .def("__call__", &PyModule::forward, call_guard)
-      .def("__call__", &PyModule::forward_single_input, call_guard);
+      .def(
+          "__call__",
+          &PyModule::forward,
+          py::arg("inputs") = py::list(),
+          py::arg("clone_outputs") = true,
+          call_guard)
+      .def(
+          "__call__",
+          &PyModule::forward_single_input,
+          py::arg("inputs") = py::list(),
+          py::arg("clone_outputs") = true,
+          call_guard);
 
   py::class_<PyBundledModule>(m, "BundledModule");
 }
diff --git a/kernels/aten/functions.yaml b/kernels/aten/functions.yaml
index e63863fc048..cba03b8a743 100644
--- a/kernels/aten/functions.yaml
+++ b/kernels/aten/functions.yaml
@@ -410,3 +410,9 @@
 - op: zeros_like.out
 
 - op: zeros.out
+
+- op: gather.out
+
+- op: scatter.value_out
+
+- op: aten::native_dropout.out
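
Note for reviewers: the sketch below condenses what model_exporter.py and runner.py above do into one minimal loop, to make the role of the new clone_outputs flag explicit. The "model.pte" path and the zero-filled placeholder batch are illustrative only and are not part of this change; the APIs used (_load_for_executorch_from_buffer, the __et_training_* index methods, forward with clone_outputs, update_function) all come from the files added above. Passing clone_outputs=False matters because the returned gradient and parameter tensors then alias the module's state, so the in-place SGD step actually updates the weights used by the next forward call.

# Minimal fine-tuning loop against an exported PTE (illustrative sketch only).
import torch
from executorch.examples.llm_pte_finetuning.training_lib import update_function
from executorch.extension.pybindings.aten_lib import _load_for_executorch_from_buffer

with open("model.pte", "rb") as f:  # placeholder path; produced by model_exporter
    et_mod = _load_for_executorch_from_buffer(f.read())

# Joint-graph output layout: [loss, grads..., params...].
grad_start = et_mod.run_method("__et_training_gradients_index_forward", [])[0]
param_start = et_mod.run_method("__et_training_parameters_index_forward", [])[0]

# Placeholder batch; real inputs come from the torchtune dataloader and must be
# padded/truncated to the max_seq_len the model was exported with.
tokens = torch.zeros(1, 512, dtype=torch.long)
labels = torch.zeros(1, 512, dtype=torch.long)

# clone_outputs=False: the returned grads/params alias module state, so the
# in-place update below is visible to the next forward() call.
out = et_mod.forward((tokens, labels), clone_outputs=False)
print("loss:", out[0].item())
with torch.no_grad():
    for grad, param in zip(out[grad_start:param_start], out[param_start:]):
        update_function(param, grad, 5e-3)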