feat: add log exporting to e2e tests
Currently, the training library runs through a series of end-to-end tests which ensure the code
being tested runs without hard errors. However, we do not perform any form of validation to ensure
that the training logic and the quality of the trained models have not regressed.

This presents an issue: a change can be "correct" in the sense that no hard errors are hit, while
still introducing invisible bugs that cause models to regress in training quality or that
otherwise plague the models themselves.

This commit fixes that problem by introducing the ability to export the training loss data itself
from the test run and render the loss curve using matplotlib.
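
For reference, each line of the exported JSONL log is a JSON object whose "total_loss" field is
the only one the graphing script relies on; the remaining fields in this sample line are
illustrative:

    {"epoch": 0, "step": 12, "rank": 0, "total_loss": 2.3147, "lr": 1e-05}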

Signed-off-by: Oleg S <97077423+RobotSail@users.noreply.github.com>
RobotSail committed Oct 25, 2024
1 parent 466474a commit 8f77076
Showing 3 changed files with 233 additions and 0 deletions.
45 changes: 45 additions & 0 deletions .github/workflows/e2e-nvidia-l4-x1.yml
@@ -141,7 +141,22 @@ jobs:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: |
          . venv/bin/activate

          # set PRESERVE to true so we can retain the logs
          export PRESERVE=1
          ./scripts/e2e-ci.sh -m

          # we know that the Python training code names the file
          # f"/training_params_and_metrics_global{os.environ['RANK']}.jsonl"
          # and writes it into a directory created by `mktemp -d`.
          # Given this, we can use the following command to find the file:
          log_file=$(find /tmp -name "training_params_and_metrics_global0.jsonl")
          mv "${log_file}" training-log.jsonl

      - name: Upload training logs
        uses: actions/upload-artifact@v4
        with:
          name: training-log.jsonl
          path: ./instructlab/training-log.jsonl
          retention-days: 1
          overwrite: true
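
Aside from the CI job below, the uploaded artifact can also be fetched locally for inspection
with the GitHub CLI; a minimal sketch, where the run ID is a placeholder:

    # download the training-log.jsonl artifact from a specific workflow run
    gh run download 1234567890 -n training-log.jsonl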

  stop-medium-ec2-runner:
    needs:
@@ -150,12 +165,14 @@ jobs:
    runs-on: ubuntu-latest
    if: ${{ always() }}
    steps:
      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@e3dd6a429d7300a6a4c196c26e071d42e0343502 # v4.0.2
        with:
          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
          aws-region: ${{ secrets.AWS_REGION }}

      - name: Stop EC2 runner
        uses: machulav/ec2-github-runner@fcfb31a5760dad1314a64a0e172b78ec6fc8a17e # v2.3.6
@@ -164,6 +181,34 @@
          github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
          label: ${{ needs.start-medium-ec2-runner.outputs.label }}
          ec2-instance-id: ${{ needs.start-medium-ec2-runner.outputs.ec2-instance-id }}

      - name: Download loss data
        id: download-logs
        uses: actions/download-artifact@v4
        with:
          name: training-log.jsonl
          path: downloaded-data

      - name: Install dependencies
        run: |
          pip install -r requirements-dev.txt

      - name: Try to upload to s3
        run: |
          # print the resolved aws binary directly (shellcheck SC2005: `echo "$(cmd)"` is a useless echo)
          which aws
          output_file='./test.md'
          python scripts/create-loss-graph.py \
            --log-file "${{ steps.download-logs.outputs.download-path }}/training-log.jsonl" \
            --output-file "${output_file}" \
            --aws-region "${{ vars.AWS_REGION }}" \
            --bucket-name "${{ vars.AWS_S3_BUCKET_NAME }}" \
            --base-branch "${{ github.event.pull_request.base.ref }}" \
            --pr-number "${{ github.event.pull_request.number }}" \
            --head-sha "${{ github.event.pull_request.head.sha }}" \
            --origin-repository "${{ github.repository }}"
          cat "${output_file}" >> "${GITHUB_STEP_SUMMARY}"


  e2e-medium-workflow-complete:
    # we don't want to block PRs on failed EC2 cleanup
2 changes: 2 additions & 0 deletions requirements-dev.txt
@@ -2,6 +2,8 @@

-r requirements.txt

matplotlib
numpy
pre-commit>=3.0.4,<5.0
pylint>=2.16.2,<4.0
pylint-pydantic
186 changes: 186 additions & 0 deletions scripts/create-loss-graph.py
@@ -0,0 +1,186 @@
# SPDX-License-Identifier: Apache-2.0
# Standard
from argparse import ArgumentParser
from pathlib import Path
from subprocess import run
from typing import Dict, List
import json

# Third Party
from matplotlib import pyplot as plt
from pydantic import BaseModel


class Arguments(BaseModel):
    log_file: str | None = None
    output_file: str
    aws_region: str
    bucket_name: str
    base_branch: str
    pr_number: str
    head_sha: str
    origin_repository: str


def render_image(loss_data: List[float], outfile: Path) -> None:
    # create the plot
    plt.figure()
    plt.plot(loss_data)
    plt.xlabel("Steps")
    plt.ylabel("Loss")
    plt.title("Training performance over fixed dataset")

    if outfile.exists():
        outfile.unlink()

    plt.savefig(outfile, format="png")


def contents_from_file(log_file: Path) -> List[Dict]:
    if not log_file.exists():
        raise FileNotFoundError(f"Log file {log_file} does not exist")
    if log_file.is_dir():
        raise ValueError(f"Log file {log_file} is a directory")
    with open(log_file, "r") as f:
        return [json.loads(line) for line in f.read().splitlines()]


def read_loss_data(log_file: Path) -> List[float]:
    if not log_file:
        raise ValueError("log_file must be provided when source is file")
    contents = contents_from_file(log_file)

    # select the loss data
    loss_data = [item["total_loss"] for item in contents if "total_loss" in item]

    if not loss_data:
        raise ValueError("Loss data is empty")

    # ensure that the loss data is valid
    if not all(isinstance(loss, float) for loss in loss_data):
        raise ValueError("Loss data must be a list of floats")

    return loss_data


def write_to_s3(
    file: Path,
    bucket_name: str,
    destination: str,
):
    if not file.exists():
        raise RuntimeError(f"File {file} does not exist")

    s3_path = f"s3://{bucket_name}/{destination}"
    # don't pass check=True here: on a non-zero exit we want to surface our own
    # error message with the captured stderr rather than a bare CalledProcessError
    results = run(["aws", "s3", "cp", str(file), s3_path], capture_output=True)
    if results.returncode != 0:
        raise RuntimeError(f"failed to upload to s3: {results.stderr.decode('utf-8')}")
    print(results.stdout.decode("utf-8"))


def get_destination_path(base_ref: str, pr_number: str, head_sha: str):
    return f"pulls/{base_ref}/{pr_number}/{head_sha}/loss-graph.png"


def write_md_file(
    output_file: Path, url: str, pr_number: str, head_sha: str, origin_repository: str
):
    commit_url = f"https://github.com/{origin_repository}/commit/{head_sha}"
    # use the function parameters rather than the global `args`, so the
    # function works when called on its own
    md_template = f"""
# Loss Graph for PR {pr_number} ([{head_sha[:7]}]({commit_url}))
![Loss Graph]({url})
"""
    output_file.write_text(md_template, encoding="utf-8")


def get_url(bucket_name: str, destination: str, aws_region: str) -> str:
    return f"https://{bucket_name}.s3.{aws_region}.amazonaws.com/{destination}"


def main(args: Arguments):
    # first things first, we create the png file to upload to S3
    log_file = Path(args.log_file)
    loss_data = read_loss_data(log_file=log_file)
    output_image = Path("/tmp/loss-graph.png")
    output_file = Path(args.output_file)
    render_image(loss_data=loss_data, outfile=output_image)
    destination_path = get_destination_path(
        base_ref=args.base_branch, pr_number=args.pr_number, head_sha=args.head_sha
    )
    write_to_s3(
        file=output_image, bucket_name=args.bucket_name, destination=destination_path
    )
    s3_url = get_url(
        bucket_name=args.bucket_name,
        destination=destination_path,
        aws_region=args.aws_region,
    )
    write_md_file(
        output_file=output_file,
        url=s3_url,
        pr_number=args.pr_number,
        head_sha=args.head_sha,
        origin_repository=args.origin_repository,
    )
    print(f"Loss graph uploaded to '{s3_url}'")
    print(f"Markdown file written to '{output_file}'")


if __name__ == "__main__":
    parser = ArgumentParser()

    parser.add_argument(
        "--log-file",
        type=str,
        required=True,
        help="The log file to read the loss data from.",
    )
    parser.add_argument(
        "--output-file",
        type=str,
        required=True,
        help="The output file where the resulting markdown will be written.",
    )
    parser.add_argument(
        "--aws-region",
        type=str,
        required=True,
        help="The AWS region to which the S3 bucket belongs.",
    )
    parser.add_argument(
        "--bucket-name", type=str, required=True, help="The S3 bucket name."
    )
    parser.add_argument(
        "--base-branch",
        type=str,
        required=True,
        help="The base branch being merged to.",
    )
    parser.add_argument("--pr-number", type=str, required=True, help="The PR number.")
    parser.add_argument(
        "--head-sha", type=str, required=True, help="The head SHA of the PR."
    )
    parser.add_argument(
        "--origin-repository",
        type=str,
        required=True,
        help="The repository to which the originating branch belongs.",
    )

    args = parser.parse_args()

    arguments = Arguments(
        log_file=args.log_file,
        output_file=args.output_file,
        aws_region=args.aws_region,
        bucket_name=args.bucket_name,
        base_branch=args.base_branch,
        pr_number=args.pr_number,
        head_sha=args.head_sha,
        origin_repository=args.origin_repository,
    )
    main(arguments)
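
For local sanity-checking, a minimal sketch of an invocation; the region, bucket, branch, PR
number, SHA, and repository below are placeholders, and AWS credentials with write access to
the bucket are assumed to be configured in the environment:

    python scripts/create-loss-graph.py \
      --log-file training-log.jsonl \
      --output-file loss-graph.md \
      --aws-region us-east-1 \
      --bucket-name my-example-bucket \
      --base-branch main \
      --pr-number 1234 \
      --head-sha abc1234 \
      --origin-repository my-org/my-repo

    cat loss-graph.md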
