|
16 | 16 | # Original Source Code Form
|
17 | 17 | # [EasyLM](https://github.com/young-geng/EasyLM/tree/main)
|
18 | 18 |
|
19 |
| -import os |
20 |
| -from shutil import copyfile |
21 |
| -from typing import Any, Dict, List, Optional, Tuple, Union |
22 | 19 | import json
|
| 20 | +import os |
23 | 21 | import tempfile
|
24 | 22 | from functools import partial
|
25 |
| -from jax import jit |
26 |
| -import numpy as np |
| 23 | +from shutil import copyfile |
| 24 | +from typing import Any, Dict, List, Optional, Tuple, Union |
| 25 | + |
| 26 | +import einops |
| 27 | +import flax.linen as nn |
27 | 28 | import jax
|
28 | 29 | import jax.numpy as jnp
|
29 |
| -from jax import lax |
30 |
| -from jax.sharding import PartitionSpec as PS |
31 |
| -import flax.linen as nn |
| 30 | +import numpy as np |
| 31 | +import sentencepiece as spm |
| 32 | +from EasyLM.bpt import blockwise_attn, blockwise_ffn |
| 33 | +from EasyLM.jax_utils import ( |
| 34 | + get_gradient_checkpoint_policy, |
| 35 | + get_jax_mesh, |
| 36 | + with_sharding_constraint, |
| 37 | +) |
32 | 38 | from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
33 | 39 | from flax.linen import combine_masks, make_causal_mask
|
| 40 | +from flax.linen import partitioning as nn_partitioning |
34 | 41 | from flax.linen.attention import dot_product_attention_weights
|
35 | 42 | from flax.traverse_util import flatten_dict, unflatten_dict
|
36 |
| -from flax.linen import partitioning as nn_partitioning |
37 |
| -import einops |
38 |
| - |
39 |
| -import sentencepiece as spm |
| 43 | +from jax import jit, lax |
| 44 | +from jax.sharding import PartitionSpec as PS |
| 45 | +from ml_collections import ConfigDict |
| 46 | +from ml_collections.config_dict import config_dict |
| 47 | +from mlxu import function_args_to_config, load_pickle, open_file |
40 | 48 | from transformers.configuration_utils import PretrainedConfig
|
41 |
| -from transformers.utils import logging |
42 |
| -from transformers.tokenization_utils import PreTrainedTokenizer |
43 | 49 | from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
|
44 | 50 | from transformers.modeling_flax_utils import (
|
45 | 51 | ACT2FN,
|
46 | 52 | FlaxPreTrainedModel,
|
47 | 53 | append_call_sample_docstring,
|
48 | 54 | )
|
| 55 | +from transformers.tokenization_utils import PreTrainedTokenizer |
49 | 56 | from transformers.utils import (
|
50 | 57 | add_start_docstrings,
|
51 | 58 | add_start_docstrings_to_model_forward,
|
52 | 59 | logging,
|
53 | 60 | )
|
54 | 61 |
|
55 |
| - |
56 |
| -from ml_collections import ConfigDict |
57 |
| -from ml_collections.config_dict import config_dict |
58 |
| -from mlxu import function_args_to_config, load_pickle, open_file |
59 |
| - |
60 |
| -from EasyLM.bpt import blockwise_ffn, blockwise_attn |
61 |
| -from EasyLM.jax_utils import ( |
62 |
| - with_sharding_constraint, |
63 |
| - get_jax_mesh, |
64 |
| - get_gradient_checkpoint_policy, |
65 |
| -) |
66 |
| - |
67 |
| - |
68 | 62 | LLAMA_STANDARD_CONFIGS = {
|
69 | 63 | '7b': {
|
70 | 64 | 'vocab_size': 32000,
|
|
0 commit comments