Skip to content

Commit

Permalink
Add check_env_var
Browse files Browse the repository at this point in the history
  • Loading branch information
limberc committed May 23, 2024
1 parent 9ebdfec commit 77f366a
Showing 1 changed file with 28 additions and 10 deletions.
38 changes: 28 additions & 10 deletions node.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,37 @@
from transformers import AutoModelForCausalLM, AutoTokenizer

from demo import train_and_merge
from main import FLOCK_API_KEY


def check_env_var(VAR_NAME, PASSED_VAR):
    """Resolve a setting from an explicit argument or the environment.

    An explicitly passed value always wins; the environment variable named
    by ``VAR_NAME`` is only consulted when ``PASSED_VAR`` is ``None``.

    Args:
        VAR_NAME: Name of the environment variable to fall back to.
        PASSED_VAR: Value supplied by the caller, or ``None`` if not given.

    Returns:
        ``PASSED_VAR`` when it is not ``None``, otherwise the value of the
        ``VAR_NAME`` environment variable.

    Raises:
        Exception: If ``PASSED_VAR`` is ``None`` and the environment
            variable is not set.
    """
    if PASSED_VAR is not None:
        return PASSED_VAR
    # Use .get() with the *value* of VAR_NAME: indexing os.environ with a
    # missing key raises KeyError instead of returning None.
    env_value = os.environ.get(VAR_NAME)
    if env_value is None:
        raise Exception(
            f"{VAR_NAME} not found in environment variables. "
            f"Set the {VAR_NAME} environment variable, "
            f"or pass {VAR_NAME} to TrainNode directly."
        )
    return env_value


class TrainNode:
FED_LEDGER_BASE_URL = "https://fed-ledger-prod.flock.io/api/v1"

def __init__(self, task_id: int = 2,
use_proxy: bool = False):
training_args: dict = None,
HF_TOKEN: str = None,
HG_USERNAME: str = None,
FLOCK_API_KEY: str = None):
self.HF_TOKEN = check_env_var("HF_TOKEN", HF_TOKEN)
self.HG_USERNAME = check_env_var("HG_USERNAME", HG_USERNAME)
self.FLOCK_API_KEY = check_env_var("FLOCK_API_KEY", FLOCK_API_KEY)
self.task_id = task_id
self.use_proxy = use_proxy
data = self.get_task_data()
self.download_data(data)
self.content_length = data["context_length"]
self.training_args = {
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 12,
"context_length": self.content_length,
} if training_args is None else training_args

def get_task_data(self):
response = requests.get(
Expand All @@ -46,10 +64,10 @@ def download_data(self, data):

def train(self):
    """Run training and merging with the node's training arguments.

    ``self.training_args`` is forwarded to ``train_and_merge`` as keyword
    arguments. ``context_length`` falls back to ``self.content_length``
    when the training arguments do not provide it.
    """
    logger.info("Start to train the model...")
    # Merge into one dict before unpacking: passing context_length both
    # explicitly and via **self.training_args raises
    # "TypeError: got multiple values for keyword argument".
    train_kwargs = {"context_length": self.content_length, **self.training_args}
    train_and_merge(**train_kwargs)

def push(self):
hg_repo_id = "gemma-2b-flock-" + str(int(time.time()))
def push(self, hg_repo_id: str = None):
hg_repo_id = "gemma-2b-flock-" + str(int(time.time())) if hg_repo_id is None else hg_repo_id
# Load model
model = AutoModelForCausalLM.from_pretrained(
"merged_model",
Expand All @@ -62,15 +80,15 @@ def push(self):
"merged_model",
)
tokenizer.push_to_hub(
repo_id=hg_repo_id, use_temp_dir=True, token=os.environ["HF_TOKEN"]
repo_id=hg_repo_id, use_temp_dir=True, token=self.HF_TOKEN
)
logger.info("SUCCESSFULLY PUSHED TOKENIZER TO HUB")
logger.info("Start to push the model to the hub...")
model.push_to_hub(
repo_id=hg_repo_id, use_temp_dir=True, token=os.environ["HF_TOKEN"]
repo_id=hg_repo_id, use_temp_dir=True, token=self.HF_TOKEN
)
logger.info("SUCCESSFULLY PUSHED MODEL TO HUB")
self.submit_task(hg_repo_id)
self.submit_task(f"{self.HG_USERNAME}/{hg_repo_id}")

def submit_task(self, hg_repo_id: str):
payload = json.dumps(
Expand All @@ -79,7 +97,7 @@ def submit_task(self, hg_repo_id: str):
response = requests.post(
f"{self.FED_LEDGER_BASE_URL}/tasks/submit-result",
headers={
"flock-api-key": FLOCK_API_KEY,
"flock-api-key": self.FLOCK_API_KEY,
"Content-Type": "application/json",
},
data=payload,
Expand Down

0 comments on commit 77f366a

Please sign in to comment.