Skip to content

Commit

Permalink
Move training_args to train method.
Browse files Browse the repository at this point in the history
  • Loading branch information
limberc committed May 23, 2024
1 parent 5272629 commit 1b8a10d
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions node.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ class TrainNode:
FED_LEDGER_BASE_URL = "https://fed-ledger-prod.flock.io/api/v1"

def __init__(self, task_id: int = 2,
training_args: dict = None,
HF_TOKEN: str = None,
HF_USERNAME: str = None,
FLOCK_API_KEY: str = None):
Expand All @@ -33,11 +32,6 @@ def __init__(self, task_id: int = 2,
data = self.get_task_data()
self.download_data(data)
self.content_length = data["context_length"]
self.training_args = {
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 12,
} if training_args is None else training_args

def get_task_data(self):
response = requests.get(
Expand All @@ -61,9 +55,14 @@ def download_data(self, data):
for chunk in r.iter_content(chunk_size=128):
f.write(chunk)

def train(self):
def train(self, training_args: dict = None, ):
training_args = {
"num_train_epochs": 3,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 12,
} if training_args is None else training_args
logger.info("Start to train the model...")
train_and_merge(context_length=self.content_length, **self.training_args)
train_and_merge(context_length=self.content_length, **training_args)

def load_merged_model(self, model_path='merged_model'):
model = AutoModelForCausalLM.from_pretrained(
Expand Down

0 comments on commit 1b8a10d

Please sign in to comment.