Skip to content

Commit e245382

Browse files
authored
Merge pull request #83 from github/threads
Parallelize extraction
2 parents c35283c + 83a2878 commit e245382

File tree

5 files changed

+204
-54
lines changed

5 files changed

+204
-54
lines changed

.github/workflows/dataset_measure.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ jobs:
2727
unzip -q codeql-linux64.zip
2828
env:
2929
GITHUB_TOKEN: ${{ github.token }}
30+
CODEQL_THREADS: 4 # TODO: remove this once it's set by the CLI
3031
- uses: actions/cache@v2
3132
with:
3233
path: |
@@ -46,6 +47,7 @@ jobs:
4647
run: |
4748
codeql/codeql database create \
4849
--search-path "${{ github.workspace }}" \
50+
--threads 4 \
4951
--language ruby --source-root "${{ github.workspace }}/repo" \
5052
"${{ runner.temp }}/database"
5153
- name: Measure database

Cargo.lock

Lines changed: 110 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

extractor/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,5 @@ tree-sitter-ruby = { git = "https://github.com/tree-sitter/tree-sitter-ruby.git"
1414
clap = "2.33"
1515
tracing = "0.1"
1616
tracing-subscriber = { version = "0.2", features = ["env-filter"] }
17+
rayon = "1.5.0"
18+
num_cpus = "1.13.0"

extractor/src/extractor.rs

Lines changed: 31 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -148,55 +148,40 @@ impl TrapWriter {
148148
}
149149
}
150150

151-
pub struct Extractor {
152-
pub parser: Parser,
153-
pub schema: NodeTypeMap,
154-
}
155-
156-
pub fn create(language: Language, schema: NodeTypeMap) -> Extractor {
157-
let mut parser = Parser::new();
158-
parser.set_language(language).unwrap();
151+
/// Extracts the source file at `path`, which is assumed to be canonicalized.
152+
pub fn extract(language: Language, schema: &NodeTypeMap, path: &Path) -> std::io::Result<Program> {
153+
let span = span!(
154+
Level::TRACE,
155+
"extract",
156+
file = %path.display()
157+
);
159158

160-
Extractor { parser, schema }
161-
}
159+
let _enter = span.enter();
162160

163-
impl Extractor {
164-
/// Extracts the source file at `path`, which is assumed to be canonicalized.
165-
pub fn extract<'a>(&'a mut self, path: &Path) -> std::io::Result<Program> {
166-
let span = span!(
167-
Level::TRACE,
168-
"extract",
169-
file = %path.display()
170-
);
161+
info!("extracting: {}", path.display());
171162

172-
let _enter = span.enter();
173-
174-
info!("extracting: {}", path.display());
175-
176-
let source = std::fs::read(&path)?;
177-
let tree = &self
178-
.parser
179-
.parse(&source, None)
180-
.expect("Failed to parse file");
181-
let mut trap_writer = new_trap_writer();
182-
trap_writer.comment(format!("Auto-generated TRAP file for {}", path.display()));
183-
let file_label = &trap_writer.populate_file(path);
184-
let mut visitor = Visitor {
185-
source: &source,
186-
trap_writer: trap_writer,
187-
// TODO: should we handle path strings that are not valid UTF8 better?
188-
path: format!("{}", path.display()),
189-
file_label: *file_label,
190-
token_counter: 0,
191-
toplevel_child_counter: 0,
192-
stack: Vec::new(),
193-
schema: &self.schema,
194-
};
195-
traverse(&tree, &mut visitor);
196-
197-
&self.parser.reset();
198-
Ok(Program(visitor.trap_writer.trap_output))
199-
}
163+
let mut parser = Parser::new();
164+
parser.set_language(language).unwrap();
165+
let source = std::fs::read(&path)?;
166+
let tree = parser.parse(&source, None).expect("Failed to parse file");
167+
let mut trap_writer = new_trap_writer();
168+
trap_writer.comment(format!("Auto-generated TRAP file for {}", path.display()));
169+
let file_label = &trap_writer.populate_file(path);
170+
let mut visitor = Visitor {
171+
source: &source,
172+
trap_writer: trap_writer,
173+
// TODO: should we handle path strings that are not valid UTF8 better?
174+
path: format!("{}", path.display()),
175+
file_label: *file_label,
176+
token_counter: 0,
177+
toplevel_child_counter: 0,
178+
stack: Vec::new(),
179+
schema,
180+
};
181+
traverse(&tree, &mut visitor);
182+
183+
parser.reset();
184+
Ok(Program(visitor.trap_writer.trap_output))
200185
}
201186

202187
/// Normalizes the path according to the common CodeQL specification. Assumes that

extractor/src/main.rs

Lines changed: 59 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
mod extractor;
22

3+
extern crate num_cpus;
4+
35
use clap;
46
use flate2::write::GzEncoder;
7+
use rayon::prelude::*;
58
use std::fs;
69
use std::io::{BufRead, BufWriter, Write};
710
use std::path::{Path, PathBuf};
@@ -42,6 +45,39 @@ impl TrapCompression {
4245
}
4346
}
4447

48+
/**
49+
* Gets the number of threads the extractor should use, by reading the
50+
* CODEQL_THREADS environment variable and using it as described in the
51+
* extractor spec:
52+
*
53+
* "If the number is positive, it indicates the number of threads that should
54+
* be used. If the number is negative or zero, it should be added to the number
55+
* of cores available on the machine to determine how many threads to use
56+
* (minimum of 1). If unspecified, should be considered as set to 1."
57+
*/
58+
fn num_codeql_threads() -> usize {
59+
match std::env::var("CODEQL_THREADS") {
60+
// Use 1 thread if the environment variable isn't set.
61+
Err(_) => 1,
62+
63+
Ok(num) => match num.parse::<i32>() {
64+
Ok(num) if num <= 0 => {
65+
let reduction = -num as usize;
66+
num_cpus::get() - reduction
67+
}
68+
Ok(num) => num as usize,
69+
70+
Err(_) => {
71+
tracing::error!(
72+
"Unable to parse CODEQL_THREADS value '{}'; defaulting to 1 thread.",
73+
&num
74+
);
75+
1
76+
}
77+
},
78+
}
79+
}
80+
4581
fn main() -> std::io::Result<()> {
4682
tracing_subscriber::fmt()
4783
.with_target(false)
@@ -50,6 +86,21 @@ fn main() -> std::io::Result<()> {
5086
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
5187
.init();
5288

89+
let num_threads = num_codeql_threads();
90+
tracing::info!(
91+
"Using {} {}",
92+
num_threads,
93+
if num_threads == 1 {
94+
"thread"
95+
} else {
96+
"threads"
97+
}
98+
);
99+
rayon::ThreadPoolBuilder::new()
100+
.num_threads(num_threads)
101+
.build_global()
102+
.unwrap();
103+
53104
let matches = clap::App::new("Ruby extractor")
54105
.version("1.0")
55106
.author("GitHub")
@@ -76,28 +127,28 @@ fn main() -> std::io::Result<()> {
76127

77128
let language = tree_sitter_ruby::language();
78129
let schema = node_types::read_node_types_str(tree_sitter_ruby::NODE_TYPES)?;
79-
let mut extractor = extractor::create(language, schema);
80-
for line in std::io::BufReader::new(file_list).lines() {
81-
let path = PathBuf::from(line?).canonicalize()?;
130+
let lines: std::io::Result<Vec<String>> = std::io::BufReader::new(file_list).lines().collect();
131+
let lines = lines?;
132+
lines.par_iter().try_for_each(|line| {
133+
let path = PathBuf::from(line).canonicalize()?;
82134
let trap_file = path_for(&trap_dir, &path, trap_compression.extension());
83135
let src_archive_file = path_for(&src_archive_dir, &path, "");
84-
let trap = extractor.extract(&path)?;
136+
let trap = extractor::extract(language, &schema, &path)?;
85137
std::fs::create_dir_all(&src_archive_file.parent().unwrap())?;
86138
std::fs::copy(&path, &src_archive_file)?;
87139
std::fs::create_dir_all(&trap_file.parent().unwrap())?;
88140
let trap_file = std::fs::File::create(&trap_file)?;
89141
let mut trap_file = BufWriter::new(trap_file);
90142
match trap_compression {
91143
TrapCompression::None => {
92-
write!(trap_file, "{}", trap)?;
144+
write!(trap_file, "{}", trap)
93145
}
94146
TrapCompression::Gzip => {
95147
let mut compressed_writer = GzEncoder::new(trap_file, flate2::Compression::fast());
96-
write!(compressed_writer, "{}", trap)?;
148+
write!(compressed_writer, "{}", trap)
97149
}
98150
}
99-
}
100-
return Ok(());
151+
})
101152
}
102153

103154
fn path_for(dir: &Path, path: &Path, ext: &str) -> PathBuf {

0 commit comments

Comments
 (0)