Commit d3fe979

Repo sync (#594)

1 parent 9553904 commit d3fe979

115 files changed, +3690, -1929 lines changed

.circleci/release-config.yml (-40)

@@ -30,38 +30,6 @@ parameters:
 # Define a job to be invoked later in a workflow.
 # See: https://circleci.com/docs/2.0/configuration-reference/#jobs
 jobs:
-  macOS_x64_publish:
-    macos:
-      xcode: 15.1
-    resource_class: macos.x86.medium.gen2
-    parameters:
-      python_ver:
-        type: string
-    steps:
-      - checkout
-      - run:
-          name: "Install homebrew dependencies"
-          command: |
-            brew install bazelisk cmake ninja nasm libomp wget go
-      - run:
-          name: "Install Miniconda"
-          command: |
-            wget https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh -O ~/miniconda.sh
-            bash ~/miniconda.sh -b -p $HOME/miniconda
-            source $HOME/miniconda/bin/activate
-            conda init zsh bash
-      - run:
-          name: "build package and publish"
-          command: |
-            set +e
-            conda create -n build python=<< parameters.python_ver >> -y
-            conda activate build
-
-            sh ./build_wheel_entrypoint.sh
-            python3 -m pip install twine
-            ls dist/*.whl
-
-            python3 -m twine upload -r pypi -u __token__ -p ${PYPI_TWINE_TOKEN} dist/*.whl
   macOS_arm64_publish:
     macos:
       xcode: 15.1
@@ -158,14 +126,6 @@ workflows:
       filters:
         tags:
          only: /.*/
-      - macOS_x64_publish:
-          matrix:
-            parameters:
-              python_ver: ["3.9", "3.10", "3.11"]
-          # This is mandatory to trigger a pipeline when pushing a tag
-          filters:
-            tags:
-              only: /.*/
       - macOS_arm64_publish:
           matrix:
             parameters:

CHANGELOG.md (+7)

@@ -10,6 +10,13 @@
 >
 > please add your unreleased change here.
 
+- [Feature] Add minimax approximation for log
+- [Feature] Support jax.lax.top_k
+- [Improvement] Default log approximation to minmax
+- [Improvement] Improve median performance
+
+## 20240306
+
 - [Feature] Support more generic Torch model inference
 - [Improvement] Optimize one-time setup for yacl ot
 - [Improvement] Optimize sort performance
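Aside on the first changelog entry above: a minimax approximation replaces log on a bounded interval with a fixed-degree polynomial, which is cheap to evaluate under MPC where every multiplication costs communication. A minimal sketch of the idea, using numpy's Chebyshev interpolation as a near-minimax stand-in; the degree and interval are illustrative, not the coefficients SPU actually ships:

    import numpy as np
    from numpy.polynomial.chebyshev import Chebyshev

    # Near-minimax polynomial for log on [0.5, 1.0]. Any x > 0 can be
    # range-reduced to this interval: x = m * 2**e, log(x) = log(m) + e*log(2).
    poly = Chebyshev.interpolate(np.log, deg=8, domain=[0.5, 1.0])

    xs = np.linspace(0.5, 1.0, 1000)
    print("max abs error on [0.5, 1]:", np.max(np.abs(poly(xs) - np.log(xs))))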

README.md (+2, -2)

@@ -34,9 +34,9 @@ This documentation also contains instructions for [build and testing](CONTRIBUTI
 | | Linux x86_64 | Linux aarch64 | macOS x64 | macOS Apple Silicon | Windows x64 | Windows WSL2 x64 |
 |------------|--------------|---------------|----------------|---------------------|----------------|---------------------|
 | CPU | yes | yes | yes<sup>1</sup>| yes | no | yes |
-| NVIDIA GPU | experimental | no | no | n/a | no | experimental |
+| NVIDIA GPU | experimental | no | no | n/a | no | experimental |
 
-1. Due to CI resource limitation, macOS x64 prebuild binary will no longer available since next release (0.9.x).
+1. Due to CI resource limitation, macOS x64 prebuild binary is no longer available.
 
 ### Instructions
 

bazel/patches/seal.patch (+15, -1)

@@ -218,4 +218,18 @@ index dabd3bab..afaa71dc 100644
 + else
  {
  inverse_ntt_negacyclic_harvey_lazy(get<0, 1>(J), get<2>(J));
-}
+}
+
+diff --git a/CMakeLists.txt b/CMakeLists.txt
+index 1a7a2bfd..bc4ad9d9 100644
+--- a/CMakeLists.txt
++++ b/CMakeLists.txt
+@@ -223,7 +223,7 @@ if(SEAL_USE_INTEL_HEXL)
+ message(STATUS "Intel HEXL: download ...")
+ seal_fetch_thirdparty_content(ExternalIntelHEXL)
+ else()
+- find_package(HEXL 1.2.4)
++ find_package(HEXL 1.2.5)
+ if (NOT TARGET HEXL::hexl)
+ message(FATAL_ERROR "Intel HEXL: not found")
+ endif()

bazel/repositories.bzl (+2, -2)

@@ -140,8 +140,8 @@ def _com_github_xtensor_xtl():
     )
 
 def _com_github_openxla_xla():
-    OPENXLA_COMMIT = "d1cf2382e57b1efba3bb17d6dd9d8657453405ca"
-    OPENXLA_SHA256 = "a7f439d54a4e35c7977c2ea17b3a2493b306c9629ccc8071b4962c905ac9f692"
+    OPENXLA_COMMIT = "495516d2d0b4453d5831905e152594614c8b4797"
+    OPENXLA_SHA256 = "13f6490065db594c6a7f9914e59213b6785ceb81af1f2cb28d5409f3f18aac8e"
 
     maybe(
         http_archive,
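When bumping a pin like this, the new OPENXLA_SHA256 must match the archive Bazel actually downloads. Assuming the rule resolves to the standard GitHub archive URL for the pinned commit (the real URL template lives elsewhere in repositories.bzl), the checksum can be recomputed with a few lines of Python:

    import hashlib
    import urllib.request

    COMMIT = "495516d2d0b4453d5831905e152594614c8b4797"
    url = f"https://github.com/openxla/xla/archive/{COMMIT}.tar.gz"

    # SHA256 of the downloaded tarball: the value http_archive pins.
    data = urllib.request.urlopen(url).read()
    print(hashlib.sha256(data).hexdigest())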

bazel/seal.BUILD (+8, -2)

@@ -53,11 +53,17 @@ x64_hexl_config = {
 
 spu_cmake_external(
     name = "seal",
-    cache_entries = default_config,
+    cache_entries = select({
+        ":can_use_hexl": x64_hexl_config,
+        "//conditions:default": default_config,
+    }),
     lib_source = "@com_github_microsoft_seal//:all",
     out_include_dir = "include/SEAL-4.1",
     out_static_libs = ["libseal-4.1.a"],
     deps = [
         "@com_github_facebook_zstd//:zstd",
-    ],
+    ] + select({
+        "@platforms//cpu:x86_64": ["@com_intel_hexl//:hexl"],
+        "//conditions:default": [],
+    }),
 )
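The net effect of the two select() calls: on x86_64 targets that satisfy the :can_use_hexl config_setting (defined elsewhere in this BUILD file), SEAL is configured with the HEXL cache entries and linked against @com_intel_hexl; every other platform keeps the default configuration. Bazel resolves this at analysis time; as a rough runtime analogy in Python (stand-in dicts, not the real cache entries):

    import platform

    # Illustrative stand-ins; the real dicts are default_config and
    # x64_hexl_config defined earlier in seal.BUILD.
    default_config = {"SEAL_USE_INTEL_HEXL": "OFF"}
    x64_hexl_config = {"SEAL_USE_INTEL_HEXL": "ON"}

    # Mirrors select({":can_use_hexl": ..., "//conditions:default": ...}).
    can_use_hexl = platform.machine() in ("x86_64", "AMD64")
    cache_entries = x64_hexl_config if can_use_hexl else default_config

    deps = ["@com_github_facebook_zstd//:zstd"]
    if can_use_hexl:
        deps.append("@com_intel_hexl//:hexl")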

examples/python/ml/flax_llama7b_split/flax_llama7b_split.py (+12, -14)

@@ -12,37 +12,35 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import argparse
+import json
+
 # Start nodes.
 # > bazel run -c opt //examples/python/utils:nodectl -- --config `pwd`/examples/python/ml/flax_llama_split/3pc.json" up
 # Run this example script.
 # > bazel run -c opt //examples/python/ml/flax_llama7b -- --config `pwd`/examples/python/ml/flax_llama_split/3pc.json
 import time
-import argparse
-import json
+from contextlib import contextmanager
+from typing import Any, Optional, Tuple, Union
+
+import flax.linen as nn
 import jax
-import jax.numpy as jnp
 import jax.nn as jnn
-import flax.linen as nn
-from flax.linen.linear import Array
-from typing import Any, Optional, Tuple, Union
-from transformers import LlamaTokenizer
+import jax.numpy as jnp
 from EasyLM.checkpoint import StreamingCheckpointer
 from EasyLM.models.llama.llama_model import FlaxLLaMAForCausalLM
 from EasyLM.models.llama.llama_model_splited_transformer import (
     FlaxLLaMAForCausalLMClient,
+    FlaxLLaMAForCausalLMMid,
     FlaxLLaMAForCausalLMServer,
     FlaxLLaMAModule,
-    FlaxLLaMAForCausalLMMid,
     LLaMAConfig,
 )
+from flax.linen.linear import Array
+from transformers import LlamaTokenizer
 
-
-import spu.utils.distributed as ppd
-from contextlib import contextmanager
 import spu.spu_pb2 as spu_pb2
-
-from flax.linen.linear import Array
-from typing import Any, Optional, Tuple, Union
+import spu.utils.distributed as ppd
 
 parser = argparse.ArgumentParser(description='distributed driver.')
 parser.add_argument(

examples/python/ml/flax_llama7b_split/llama_model_splited_transformer.py (+21, -27)

@@ -16,55 +16,49 @@
 # Original Source Code Form
 # [EasyLM](https://github.com/young-geng/EasyLM/tree/main)
 
-import os
-from shutil import copyfile
-from typing import Any, Dict, List, Optional, Tuple, Union
 import json
+import os
 import tempfile
 from functools import partial
-from jax import jit
-import numpy as np
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import einops
+import flax.linen as nn
 import jax
 import jax.numpy as jnp
-from jax import lax
-from jax.sharding import PartitionSpec as PS
-import flax.linen as nn
+import numpy as np
+import sentencepiece as spm
+from EasyLM.bpt import blockwise_attn, blockwise_ffn
+from EasyLM.jax_utils import (
+    get_gradient_checkpoint_policy,
+    get_jax_mesh,
+    with_sharding_constraint,
+)
 from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.linen import combine_masks, make_causal_mask
+from flax.linen import partitioning as nn_partitioning
 from flax.linen.attention import dot_product_attention_weights
 from flax.traverse_util import flatten_dict, unflatten_dict
-from flax.linen import partitioning as nn_partitioning
-import einops
-
-import sentencepiece as spm
+from jax import jit, lax
+from jax.sharding import PartitionSpec as PS
+from ml_collections import ConfigDict
+from ml_collections.config_dict import config_dict
+from mlxu import function_args_to_config, load_pickle, open_file
 from transformers.configuration_utils import PretrainedConfig
-from transformers.utils import logging
-from transformers.tokenization_utils import PreTrainedTokenizer
 from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
 from transformers.modeling_flax_utils import (
     ACT2FN,
     FlaxPreTrainedModel,
     append_call_sample_docstring,
 )
+from transformers.tokenization_utils import PreTrainedTokenizer
 from transformers.utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     logging,
 )
 
-
-from ml_collections import ConfigDict
-from ml_collections.config_dict import config_dict
-from mlxu import function_args_to_config, load_pickle, open_file
-
-from EasyLM.bpt import blockwise_ffn, blockwise_attn
-from EasyLM.jax_utils import (
-    with_sharding_constraint,
-    get_jax_mesh,
-    get_gradient_checkpoint_policy,
-)
-
-
 LLAMA_STANDARD_CONFIGS = {
     '7b': {
         'vocab_size': 32000,

examples/python/ml/flax_resnet/flax_resnet_inference.py (+3, -4)

@@ -16,8 +16,8 @@
 import json
 
 import jax
-import spu.utils.distributed as ppd
 
+import spu.utils.distributed as ppd
 
 parser = argparse.ArgumentParser(description='distributed driver.')
 parser.add_argument("-c", "--config", default="3pc.json")
@@ -29,14 +29,13 @@
 ppd.init(conf["nodes"], conf["devices"])
 
 
+from datasets import load_dataset
 from transformers import (
-    AutoImageProcessor,
     AutoConfig,
+    AutoImageProcessor,
     FlaxResNetForImageClassification,
 )
 
-from datasets import load_dataset
-
 dataset = load_dataset("huggingface/cats-image")
 image = dataset["test"]["image"][0]
 
examples/python/ml/flax_whisper/flax_whisper.py (+1, -1)

@@ -23,8 +23,8 @@
 import os
 
 import jax.numpy as jnp
-from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
 from datasets import load_dataset
+from transformers import FlaxWhisperForConditionalGeneration, WhisperProcessor
 
 import spu.utils.distributed as ppd
 from spu import spu_pb2

examples/python/ml/ml_test.py (+3, -5)

@@ -21,11 +21,11 @@
 import unittest
 from time import perf_counter
 
-import multiprocess
 import numpy.testing as npt
 import pandas as pd
 
 import spu.utils.distributed as ppd
+from spu.utils.polyfill import Process
 
 with open("examples/python/conf/3pc.json", 'r') as file:
     conf = json.load(file)
@@ -70,15 +70,13 @@ class UnitTests(unittest.TestCase):
     def setUpClass(cls):
         cls.workers = []
         for node_id in conf["nodes"].keys():
-            worker = multiprocess.Process(
-                target=ppd.RPC.serve, args=(node_id, conf["nodes"])
-            )
+            worker = Process(target=ppd.RPC.serve, args=(node_id, conf["nodes"]))
             worker.start()
             cls.workers.append(worker)
         import time
 
         # wait for all process serving.
-        time.sleep(0.05)
+        time.sleep(2)
 
         rt_config = conf["devices"]["SPU"]["config"]["runtime_config"]
         rt_config["enable_pphlo_profile"] = False
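The longer sleep simply gives the freshly started worker processes more time to begin serving before the tests connect. The new spu.utils.polyfill.Process itself is not shown in this commit; a plausible shape for such a wrapper (a hypothetical sketch, not the actual SPU implementation) is a fallback to the standard library when the multiprocess package is unavailable:

    # Hypothetical polyfill sketch; the real spu.utils.polyfill may differ.
    try:
        from multiprocess import Process  # dill-based fork of multiprocessing
    except ImportError:
        from multiprocessing import Process  # stdlib fallback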

examples/python/ml/torch_lr_experiment/torch_lr_experiment.py (+1, -1)

@@ -22,7 +22,6 @@
 
 import spu.utils.distributed as ppd
 
-
 # Start nodes.
 # > bazel run -c opt //examples/python/utils:nodectl -- up
 #
@@ -114,6 +113,7 @@ def run_inference_on_cpu(model):
 ppd.init(conf["nodes"], conf["devices"], framework=ppd.Framework.EXP_TORCH)
 
 from collections import OrderedDict
+
 from jax.tree_util import tree_map
 
 
examples/python/utils/nodectl.py (+2, -5)

@@ -15,9 +15,8 @@
 import argparse
 import json
 
-import multiprocess
-
 import spu.utils.distributed as ppd
+from spu.utils.polyfill import Process
 
 parser = argparse.ArgumentParser(description='SPU node service.')
 parser.add_argument(
@@ -44,9 +43,7 @@
 elif args.command == 'up':
     workers = []
     for node_id in nodes_def.keys():
-        worker = multiprocess.Process(
-            target=ppd.RPC.serve, args=(node_id, nodes_def)
-        )
+        worker = Process(target=ppd.RPC.serve, args=(node_id, nodes_def))
         worker.start()
         workers.append(worker)
 

libspu/compiler/core/core.cc (+4)

@@ -55,6 +55,10 @@ void Core::buildPipeline(mlir::PassManager *pm) {
   optPM.addPass(mlir::spu::pphlo::createDecomposeMinMaxPass());
   optPM.addPass(mlir::spu::pphlo::createSortLowering());
 
+  if (!options.disable_partial_sort_optimization()) {
+    optPM.addPass(mlir::spu::pphlo::createPartialSortToTopK());
+  }
+
   if (!options.disable_sqrt_plus_epsilon_rewrite()) {
     optPM.addPass(mlir::spu::pphlo::createOptimizeSqrtPlusEps());
   }
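This new pass connects to the jax.lax.top_k and median entries in the CHANGELOG: a full sort whose result is only partially consumed can be rewritten as a top-k computation, which avoids ordering the entire tensor and is therefore cheaper under MPC. A small illustration of the two frontend patterns (the exact IR patterns the pass matches live in the pass implementation):

    import jax
    import jax.numpy as jnp

    x = jnp.array([5.0, 1.0, 4.0, 2.0, 3.0])

    # Pattern 1: full sort, then keep only the k largest entries (ascending).
    top3_via_sort = jnp.sort(x)[-3:]        # [3., 4., 5.]

    # Pattern 2: direct top-k (descending); no need to order the whole tensor.
    values, indices = jax.lax.top_k(x, 3)   # [5., 4., 3.], [0, 2, 4]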
