First try on integration tests
LLukas22 committed Jun 29, 2023
1 parent 60d6168 commit 63f614c
Showing 11 changed files with 325 additions and 3 deletions.
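In short, this commit adds a new `llm-test` binary, one JSON config per supported architecture (llama, gptneox, gptj, mpt, bloom), a GitHub Actions workflow that tests each architecture in its own matrix job, and a `Serialize` derive on `InferenceStats` so that inference results can be written out as JSON reports.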
45 changes: 45 additions & 0 deletions .github/workflows/integration_tests.yml
@@ -0,0 +1,45 @@
name: Integration Tests

permissions:
  contents: write

on:
  push:
    branches: ["main"]
  pull_request:
    branches: ["main"]
  workflow_dispatch:

env:
  CARGO_TERM_COLOR: always


jobs:
  build:
    strategy:
      # Don't stop testing if an architecture fails
      fail-fast: false
      matrix:
        model: [llama, gptneox, gptj, mpt, bloom]
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
        with:
          submodules: recursive
      - uses: dtolnay/rust-toolchain@stable
      - name: Install dependencies
        run: |
          sudo apt-get update
          sudo apt-get install -y \
            libssl-dev \
            pkg-config \
            zlib1g-dev
      - name: Run Integration Tests for ${{ matrix.model }}
        run: cargo run --release --bin llm-test ${{ matrix.model }}
        continue-on-error: true
      # Upload test results
      - uses: actions/upload-artifact@v3
        if: always()
        with:
          name: test-reports
          path: ./.tests/results/*.json
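Each matrix job checks out the repository (including submodules), installs the native build dependencies (OpenSSL, pkg-config, zlib), and runs a single architecture through the new test binary; the same command (`cargo run --release --bin llm-test <architecture>`) works locally from the repository root. Because the test step sets `continue-on-error: true` and the upload step runs with `if: always()`, a failing model does not prevent the JSON reports in `./.tests/results/` from being uploaded as the `test-reports` artifact.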
3 changes: 2 additions & 1 deletion .gitignore
@@ -1,3 +1,4 @@
/target
/models
.DS_Store
.DS_Store
/.tests
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -9,7 +9,7 @@ members = [
"binaries/*"
]
resolver = "2"
default-members = ["binaries/llm-cli", "crates/llm"]
default-members = ["binaries/llm-cli", "crates/llm","binaries/llm-test"]

[workspace.package]
repository = "https://github.com/rustformers/llm"
33 changes: 33 additions & 0 deletions binaries/llm-test/Cargo.toml
@@ -0,0 +1,33 @@
[package]
edition = "2021"
name = "llm-test"
version = "0.2.0-dev"
repository = { workspace = true }
license = { workspace = true }

[[bin]]
name = "llm-test"
path = "src/main.rs"

[dependencies]
llm = { path = "../../crates/llm", version = "0.2.0-dev" }
reqwest = "0.11.9"
indicatif = "0.16.2"
tokio = { version = "1.14.0", features = ["full"] }
tokio-stream = "0.1.8"
tokio-util = "0.7.1"
serde = "1.0.130"
serde_json = "1.0.67"
bytes = "1.0.1"
rand = { workspace = true }

[dev-dependencies]
rusty-hook = "^0.11.2"

[features]
cublas = ["llm/cublas"]
clblast = ["llm/clblast"]
metal = ["llm/metal"]

# Falcon is off by default. See `llm_falcon`'s module documentation for more information.
falcon = ["llm/falcon"]
5 changes: 5 additions & 0 deletions binaries/llm-test/configs/bloom.json
@@ -0,0 +1,5 @@
{
    "url": "https://huggingface.co/rustformers/bloom-ggml/resolve/main/bloom-560m-q4_0.bin",
    "filename": "bloom.bin",
    "architecture": "bloom"
}
5 changes: 5 additions & 0 deletions binaries/llm-test/configs/gptj.json
@@ -0,0 +1,5 @@
{
    "url": "https://huggingface.co/rustformers/gpt-j-ggml/resolve/main/gpt-j-6b-q4_0-ggjt.bin",
    "filename": "gptj.bin",
    "architecture": "gptj"
}
5 changes: 5 additions & 0 deletions binaries/llm-test/configs/gptneox.json
@@ -0,0 +1,5 @@
{
    "url": "https://huggingface.co/rustformers/redpajama-3b-ggml/resolve/main/RedPajama-INCITE-Base-3B-v1-q4_0-ggjt.bin",
    "filename": "gptneox.bin",
    "architecture": "gptneox"
}
5 changes: 5 additions & 0 deletions binaries/llm-test/configs/llama.json
@@ -0,0 +1,5 @@
{
    "url": "https://huggingface.co/rustformers/open-llama-ggml/resolve/main/open_llama_3b-q4_0-ggjt.bin",
    "filename": "llama.bin",
    "architecture": "llama"
}
5 changes: 5 additions & 0 deletions binaries/llm-test/configs/mpt.json
@@ -0,0 +1,5 @@
{
    "url": "https://huggingface.co/rustformers/mpt-7b-ggml/resolve/main/mpt-7b-q4_0-ggjt.bin",
    "filename": "mpt.bin",
    "architecture": "mpt"
}
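Each config deserializes into the `TestCase` struct in `binaries/llm-test/src/main.rs` below, and the file stem (`bloom`, `gptj`, `gptneox`, `llama`, `mpt`) doubles as the name passed on the command line and used in the CI matrix.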
217 changes: 217 additions & 0 deletions binaries/llm-test/src/main.rs
@@ -0,0 +1,217 @@
extern crate indicatif;
extern crate reqwest;
extern crate tokio;

use indicatif::{ProgressBar, ProgressStyle};
use llm::InferenceStats;
use rand::rngs::StdRng;
use rand::SeedableRng;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::cmp::min;
use std::collections::HashMap;
use std::convert::Infallible;
use std::env;
use std::fs::{self, File};
use std::io::Write;
use std::path::{Path, PathBuf};
use std::str::FromStr;

async fn download_file(url: &str, local_path: &PathBuf) -> Result<(), Box<dyn std::error::Error>> {
    if Path::new(local_path).exists() {
        println!("Model already exists at {}", local_path.to_str().unwrap());
        return Ok(());
    }

    let client = Client::new();

    let mut res = client.get(url).send().await?;
    let total_size = res.content_length().ok_or("Failed to get content length")?;

    let pb = ProgressBar::new(total_size);
    pb.set_style(ProgressStyle::default_bar()
        .template("{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes} ({eta})")
        .progress_chars("#>-"));

    let mut file = File::create(local_path)?;
    let mut downloaded: u64 = 0;

    while let Some(chunk) = res.chunk().await? {
        file.write_all(&chunk)?;
        let new = min(downloaded + (chunk.len() as u64), total_size);
        downloaded = new;
        pb.set_position(new);
    }

    pb.finish_with_message("Download complete");

    Ok(())
}

#[derive(Deserialize, Debug)]
struct TestCase {
    url: String,
    filename: String,
    architecture: String,
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // This should be done with clap, but I'm lazy
    let args: Vec<String> = env::args().collect();

    let mut specific_model = None;
    if args.len() > 1 {
        println!("Testing architecture: {}", args[1].to_lowercase());
        specific_model = Some(args[1].to_lowercase());
    } else {
        println!("Testing all architectures.");
    }

    let mut configs = HashMap::new();
    let cwd = std::env::current_dir()?;
    let configs_dir = cwd.join("binaries/llm-test/configs");
    let download_dir = cwd.join(".tests/models");
    fs::create_dir_all(&download_dir)?;
    let results_dir = cwd.join(".tests/results");
    fs::create_dir_all(&results_dir)?;

    // Load every JSON config into a map keyed by its file stem (e.g. "llama").
    for entry in fs::read_dir(configs_dir)? {
        let entry = entry?;
        let path = entry.path();
        if path.is_file() {
            if let Some(extension) = path.extension() {
                if extension == "json" {
                    let file_name = path.file_stem().unwrap().to_str().unwrap().to_string();
                    let config: TestCase = serde_json::from_str(&fs::read_to_string(path)?)?;
                    configs.insert(file_name, config);
                }
            }
        }
    }

    if let Some(specific_architecture) = specific_model {
        if let Some(config) = configs.get(&specific_architecture) {
            println!("Key: {}, Config: {:?}", specific_architecture, config);
            test_model(config, &download_dir, &results_dir).await?;
        } else {
            println!("No config found for {}", specific_architecture);
        }
    } else {
        for (key, config) in &configs {
            println!("Key: {}, Config: {:?}", key, config);
            test_model(config, &download_dir, &results_dir).await?;
        }
    }
    println!("All tests passed!");
    Ok(())
}

#[derive(Serialize)]
pub struct Report {
    pub could_loaded: bool,
    pub inference_stats: Option<InferenceStats>,
    pub error: Option<String>,
    pub output: String,
}

async fn test_model(
    config: &TestCase,
    download_dir: &Path,
    results_dir: &Path,
) -> Result<(), Box<dyn std::error::Error>> {
    println!("Testing architecture: `{}` ...", config.architecture);

    let local_path = download_dir.join(&config.filename);

    // Download the model
    download_file(&config.url, &local_path).await?;

    let now = std::time::Instant::now();

    let architecture = llm::ModelArchitecture::from_str(&config.architecture)?;
    // Load the model
    let model = llm::load_dynamic(
        architecture,
        &local_path,
        llm::TokenizerSource::Embedded,
        Default::default(),
        llm::load_progress_callback_stdout,
    )
    .unwrap_or_else(|err| panic!("Failed to load {architecture} model from {local_path:?}: {err}"));

    println!(
        "Model fully loaded! Elapsed: {}ms",
        now.elapsed().as_millis()
    );

    // Run the model
    let mut session = model.start_session(Default::default());

    let prompt = "write a story about a lama riding a crab:";
    // Fixed seed so the sampled output is reproducible between runs.
    let mut rng: StdRng = SeedableRng::seed_from_u64(42);
    let mut output: String = String::new();

    println!("Running inference...");
    let res = session.infer::<Infallible>(
        model.as_ref(),
        &mut rng,
        &llm::InferenceRequest {
            prompt: prompt.into(),
            parameters: &llm::InferenceParameters::default(),
            play_back_previous_tokens: false,
            maximum_token_count: Some(50),
        },
        // OutputRequest
        &mut Default::default(),
        |r| match r {
            llm::InferenceResponse::PromptToken(t) | llm::InferenceResponse::InferredToken(t) => {
                output += &t;
                Ok(llm::InferenceFeedback::Continue)
            }
            _ => Ok(llm::InferenceFeedback::Continue),
        },
    );
    println!("Inference done!");

    let inference_results: Option<llm::InferenceStats>;
    let error: Option<llm::InferenceError>;

    match res {
        Ok(result) => {
            inference_results = Some(result);
            error = None;
        }
        Err(err) => {
            inference_results = None;
            error = Some(err);
        }
    }

    // Save the results
    let report = Report {
        could_loaded: true,
        inference_stats: inference_results,
        error: error.map(|e| format!("{:?}", e)),
        output,
    };

    // Serialize the report to a JSON string
    let json_report = serde_json::to_string(&report).unwrap();
    let report_path = results_dir.join(format!("{}.json", config.architecture));
    match fs::write(report_path, json_report) {
        Ok(_) => println!("Report successfully written to file."),
        Err(e) => println!("Failed to write report to file: {}", e),
    }

    // Fail the run if inference returned an error.
    if let Some(err) = &report.error {
        panic!("Error: {}", err);
    }

    println!(
        "Successfully tested architecture `{}`!",
        config.architecture
    );

    Ok(())
}
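To make the report format concrete: `test_model` writes one `Report` per architecture to `.tests/results/<architecture>.json`. The sketch below is purely illustrative (not part of this commit, and it assumes `serde_json` as a dependency): it reads such a report back as untyped JSON and prints a summary. The path and field names come from the code above; everything else is hypothetical.

// Hypothetical report consumer; not part of this commit.
use std::fs;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let raw = fs::read_to_string(".tests/results/llama.json")?;
    let report: serde_json::Value = serde_json::from_str(&raw)?;

    // `could_loaded`, `error` and `output` are the fields of `Report` above.
    println!("loaded: {}", report["could_loaded"]);
    match report["error"].as_str() {
        Some(err) => println!("inference failed: {err}"),
        None => println!("output: {}", report["output"]),
    }
    Ok(())
}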
3 changes: 2 additions & 1 deletion crates/llm-base/src/inference_session.rs
@@ -1,4 +1,5 @@
use ggml::{Buffer, ComputationGraph, Context, Tensor};
use serde::Serialize;
use std::{fmt::Display, sync::Arc};
use thiserror::Error;

@@ -734,7 +735,7 @@ pub struct InferenceRequest<'a> {
}

/// Statistics about the inference process.
#[derive(Debug, Clone, Copy)]
#[derive(Serialize, Debug, Clone, Copy)]
pub struct InferenceStats {
    /// How long it took to feed the prompt.
    pub feed_prompt_duration: std::time::Duration,
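The new `Serialize` derive is what allows the test harness above to embed `InferenceStats` in its JSON `Report` via `serde_json`; without it, serializing the report would not compile.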
