15 changes: 13 additions & 2 deletions problemtools/formatversion.py
@@ -19,11 +19,22 @@ class FormatData:
     name: str
     statement_directory: str
     statement_extensions: list[str]
+    output_validator_directory: str
 
 
 FORMAT_DATACLASSES = {
-    VERSION_LEGACY: FormatData(name=VERSION_LEGACY, statement_directory='problem_statement', statement_extensions=['tex']),
-    VERSION_2023_07: FormatData(name=VERSION_2023_07, statement_directory='statement', statement_extensions=['md', 'tex']),
+    VERSION_LEGACY: FormatData(
+        name=VERSION_LEGACY,
+        statement_directory='problem_statement',
+        statement_extensions=['tex'],
+        output_validator_directory='output_validators',
+    ),
+    VERSION_2023_07: FormatData(
+        name=VERSION_2023_07,
+        statement_directory='statement',
+        statement_extensions=['md', 'tex'],
+        output_validator_directory='output_validator',
+    ),
 }
 FORMAT_DATACLASSES['2023-07'] = FORMAT_DATACLASSES[VERSION_2023_07]  # Accept non-draft version string too
 
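Note: each format version now carries its own output validator directory name, so callers can stop hard-coding 'output_validators'. A minimal sketch of the per-version lookup, using get_format_data from this module (the problem path is hypothetical):

    import os
    from problemtools import formatversion

    # get_format_data detects the problem's format version and returns the
    # matching FormatData entry from FORMAT_DATACLASSES.
    fmt = formatversion.get_format_data('/path/to/problem')  # hypothetical path
    validator_dir = os.path.join('/path/to/problem', fmt.output_validator_directory)
    # legacy problems resolve to .../output_validators,
    # 2023-07 problems resolve to .../output_validator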
50 changes: 27 additions & 23 deletions problemtools/verifyproblem.py
@@ -279,7 +279,7 @@ def check(self, context: Context) -> bool:
             self.warning(
                 f'Answer file ({anssize:.1f} Mb) is within 50% of output limit ({outputlim} Mb), you might want to increase output limit'
             )
-        if not self._problem.get(ProblemTestCases)['is_interactive']:
+        if not self._problem.getMetadata().is_interactive():
             val_res = self._problem.getProblemPart(OutputValidators).validate(self, self.ansfile)
             if val_res.verdict != 'AC':
                 if self.is_in_sample_group():
@@ -339,7 +339,7 @@ def run_submission(self, sub, runner: Runner, context: Context) -> Result:
 
     def run_submission_real(self, sub, context: Context, timelim: int, timelim_low: int, timelim_high: int) -> Result:
         # This may be called off-main thread.
-        if self._problem.get(ProblemTestCases)['is_interactive']:
+        if self._problem.getMetadata().is_interactive():
             res_high = self._problem.getProblemPart(OutputValidators).validate_interactive(
                 self, sub, timelim_high, self._problem.getProblemPart(Submissions)
             )
@@ -543,15 +543,15 @@ def check(self, context: Context) -> bool:
             if field not in TestCaseGroup._DEFAULT_CONFIG.keys():
                 self.warning(f"Unknown key '{field}' in '{os.path.join(self._datadir, 'testdata.yaml')}'")
 
-        if not self._problem.get(ProblemTestCases)['is_scoring']:
+        if not self._problem.getMetadata().is_scoring():
             for key in TestCaseGroup._SCORING_ONLY_KEYS:
                 if self.config.get(key) is not None:
                     self.error(f"Key '{key}' is only applicable for scoring problems, this is a pass-fail problem")
 
         if self.config['on_reject'] not in ['break', 'continue']:
             self.error(f"Invalid value '{self.config['on_reject']}' for on_reject policy")
 
-        if self._problem.get(ProblemTestCases)['is_scoring']:
+        if self._problem.getMetadata().is_scoring():
             # Check grading
             try:
                 score_range = self.config['range']
@@ -714,7 +714,7 @@ def aggregate_results(self, sub, sub_results: list[SubmissionResult], shadow_res
         if sub_results:
             res.testcase = sub_results[-1].testcase
             res.additional_info = sub_results[-1].additional_info
-        if self._problem.get(ProblemTestCases)['is_scoring']:
+        if self._problem.getMetadata().is_scoring():
             res.score = score
             min_score, max_score = self.get_score_range()
             if score is not None and not (min_score <= score <= max_score) and not self._seen_oob_scores:
@@ -738,20 +738,17 @@ class ProblemStatement(ProblemPart):
     PART_NAME = 'statement'
 
     def setup(self):
-        self.format_data = formatversion.get_format_data(self.problem.probdir)
-        if not self.format_data:
-            raise NotImplementedError('No version selected.')
         self.debug(' Loading problem statement')
         self.statement_regex = re.compile(
-            r'problem(\.([a-z]{2,3}|[a-z]{2}-[A-Z]{2}))?\.(%s)$' % ('|'.join(self.format_data.statement_extensions))
+            r'problem(\.([a-z]{2,3}|[a-z]{2}-[A-Z]{2}))?\.(%s)$' % ('|'.join(self.problem.format.statement_extensions))
         )
-        dir = os.path.join(self.problem.probdir, self.format_data.statement_directory)
+        dir = os.path.join(self.problem.probdir, self.problem.format.statement_directory)
         if os.path.isdir(dir):
             self.statements = [
                 (m.group(0), m.group(2) or '') for file in os.listdir(dir) if (m := re.search(self.statement_regex, file))
             ]
         else:
-            self.error(f'No directory named {self.format_data.statement_directory} found')
+            self.error(f'No directory named {self.problem.format.statement_directory} found')
             self.statements = []
 
         return self.get_config()
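For reference, a small sketch of what the statement filename pattern accepts once the 2023-07 extensions (md, tex) are substituted in; the pattern string matches the one in this diff:

    import re

    pattern = re.compile(r'problem(\.([a-z]{2,3}|[a-z]{2}-[A-Z]{2}))?\.(md|tex)$')
    for name in ['problem.md', 'problem.en.tex', 'problem.pt-BR.md', 'problem.english.tex']:
        m = pattern.search(name)
        print(name, '->', (m.group(2) or 'no language suffix') if m else 'rejected')
    # problem.md          -> no language suffix
    # problem.en.tex      -> en
    # problem.pt-BR.md    -> pt-BR
    # problem.english.tex -> rejected (language codes are 2-3 letters or ll-CC)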
@@ -763,10 +760,10 @@ def check(self, context: Context) -> bool:
 
         if not self.statements:
             allowed_statements = ', '.join(
-                f'problem.{ext}, problem.[a-z][a-z].{ext}' for ext in self.format_data.statement_extensions
+                f'problem.{ext}, problem.[a-z][a-z].{ext}' for ext in self.problem.format.statement_extensions
             )
             self.error(
-                f'No problem statements found (expected file of one of following forms in directory {self.format_data.statement_directory}/: {allowed_statements})'
+                f'No problem statements found (expected file of one of following forms in directory {self.problem.format.statement_directory}/: {allowed_statements})'
             )
 
         langs = [lang or 'en' for _, lang in self.statements]
@@ -808,7 +805,7 @@ def __str__(self) -> str:
     def get_config(self) -> dict[str, dict[str, str]]:
         ret: dict[str, dict[str, str]] = {'name': {}}
         for filename, lang in self.statements:
-            dir = os.path.join(self.problem.probdir, self.format_data.statement_directory)
+            dir = os.path.join(self.problem.probdir, self.problem.format.statement_directory)
             with open(os.path.join(dir, filename)) as f:
                 stmt = f.read()
                 hit = re.search(r'\\problemname{(.*)}', stmt, re.MULTILINE)
@@ -847,14 +844,14 @@ def setup(self):
 
         try:
             self._metadata = metadata.parse_metadata(
-                formatversion.get_format_data(self.problem.probdir),
+                self.problem.format,
                 self._data,
                 self.problem.get(ProblemStatement).get('name', {}),
             )
             self.problem.setMetadata(self._metadata)
         except ValidationError as e:
             # This should likely be a fatal error, but I'm not sure there's a clean way to fail from setup
-            error_str = '\n'.join([f' {"->".join(str(err["loc"]))}: {err["msg"]}' for err in e.errors()])
+            error_str = '\n'.join([f' {"->".join((str(loc) for loc in err["loc"]))}: {err["msg"]}' for err in e.errors()])
             self.error(f'Failed parsing problem.yaml. Found {len(e.errors())} errors:\n{error_str}')
             return {}
 
@@ -927,14 +924,12 @@ class ProblemTestCases(ProblemPart):
 
     @staticmethod
     def setup_dependencies():
-        return {ProblemConfig}
+        return {ProblemConfig}  # We need this as the TestCaseGroup constructor reads config
 
     def setup(self):
         self.testcase_by_infile = {}
         return {
             'root_group': TestCaseGroup(self.problem, self.PART_NAME),
-            'is_interactive': self.problem.getMetadata().is_interactive(),
-            'is_scoring': self.problem.getMetadata().is_scoring(),
         }
 
     def check(self, context: Context) -> bool:
@@ -1239,7 +1234,7 @@ class OutputValidators(ProblemPart):
 
     def setup(self):
         self._validators = run.find_programs(
-            os.path.join(self.problem.probdir, 'output_validators'),
+            os.path.join(self.problem.probdir, self.problem.format.output_validator_directory),
             language_config=self.problem.language_config,
             work_dir=self.problem.tmpdir,
         )
@@ -1268,7 +1263,7 @@ def check(self, context: Context) -> bool:
 
         if self.problem.getMetadata().legacy_validation == 'default' and self._validators:
             self.error('There are validator programs but problem.yaml has validation = "default"')
-        elif self.problem.getMetadata().legacy_validation != 'default' and not self._validators:
+        elif self.problem.getMetadata().legacy_validation.startswith('custom') and not self._validators:
             self.error('problem.yaml specifies custom validator but no validator programs found')
 
         if self.problem.getMetadata().legacy_validation == 'default' and self._default_validator is None:
@@ -1365,7 +1360,9 @@ def _parse_validator_results(self, val, status: int, feedbackdir, testcase: Test
 
     def _actual_validators(self) -> list:
         vals = self._validators
-        if self.problem.getMetadata().legacy_validation == 'default':
+        if self.problem.getMetadata().legacy_validation == 'default' or (
+            self.problem.format.name == formatversion.VERSION_2023_07 and not vals
+        ):
             vals = [self._default_validator]
         return [val for val in vals if val is not None]
 
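The widened condition means a 2023-07 problem falls back to the default output validator whenever no custom validator program is present, rather than only when validation = "default". A hypothetical standalone version of that decision, for illustration only (not the actual method):

    from problemtools import formatversion

    def pick_validators(legacy_validation, format_name, custom_validators, default_validator):
        # Legacy: use the default validator only when explicitly requested.
        # 2023-07: also fall back to it when no custom validator exists.
        vals = custom_validators
        if legacy_validation == 'default' or (
            format_name == formatversion.VERSION_2023_07 and not vals
        ):
            vals = [default_validator]
        return [v for v in vals if v is not None]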
@@ -1673,7 +1670,7 @@ def full_score_finite(self) -> bool:
     def fully_accepted(self, result: SubmissionResult) -> bool:
         min_score, max_score = self.problem.get(ProblemTestCases)['root_group'].get_score_range()
         best_score = min_score if self.problem.getMetadata().legacy_grading.objective == 'min' else max_score
-        return result.verdict == 'AC' and (not self.problem.get(ProblemTestCases)['is_scoring'] or result.score == best_score)
+        return result.verdict == 'AC' and (not self.problem.getMetadata().is_scoring() or result.score == best_score)
 
     def start_background_work(self, context: Context) -> None:
         # Send off an early background compile job for each submission and
@@ -1764,6 +1761,10 @@ def check(self, context: Context) -> bool:
     formatversion.VERSION_2023_07: {  # TODO: Add all the parts
         'config': [ProblemConfig],
         'statement': [ProblemStatement, Attachments],
+        'validators': [InputValidators, OutputValidators],
+        'graders': [Graders],
+        'data': [ProblemTestCases],
+        'submissions': [Submissions],
     },
 }
 
@@ -1789,6 +1790,7 @@ def __init__(self, probdir: str, parts: dict[str, list[type]] = PROBLEM_FORMATS[
         self.shortname: str | None = os.path.basename(self.probdir)
         super().__init__(self.shortname)
         self.language_config = languages.load_language_config()
+        self.format = formatversion.get_format_data(self.probdir)
         self._data: dict[str, dict] = {}
         self._metadata: metadata.Metadata | None = None
         self.debug(f'Problem-format: {parts}')
@@ -1868,6 +1870,8 @@ def check(self, args: argparse.Namespace) -> tuple[int, int]:
         try:
             if not re.match('^[a-z0-9]+$', self.shortname):
                 self.error(f"Invalid shortname '{self.shortname}' (must be [a-z0-9]+)")
+            if self.format.name == formatversion.VERSION_2023_07:
+                self.warning(f'Support for version {self.format.name} is very incomplete. Verification may not work as expected.')
 
             self._check_symlinks()
 
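Taken together, Problem now resolves its FormatData once in __init__, and the individual parts consult self.problem.format instead of re-reading the version from disk. A hedged usage sketch (the context-manager style follows existing problemtools usage; the namespace fields are assumed, and a real run goes through verifyproblem's own CLI parser):

    import argparse
    from problemtools import verifyproblem

    # Hypothetical driver: Problem detects its format version in __init__,
    # so prob.format is available to every registered part during check().
    args = argparse.Namespace(bail_on_error=False)  # fields assumed
    with verifyproblem.Problem('/path/to/problem') as prob:  # hypothetical path
        prob.check(args)  # warns that 2023-07 support is still incomplete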