Skip to content

Commit a4fb3ac

Browse files
committed
Emit attribute to enable idiomatic subclass (i.e. enum variant) access from Java. Added an integration test to verify this.
1 parent e06f2f0 commit a4fb3ac

File tree

10 files changed

+139
-4
lines changed

10 files changed

+139
-4
lines changed

Tester.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,17 @@ def process_binary_test(test_dir: str, release_mode: bool) -> bool:
127127
test_name = os.path.basename(test_dir)
128128
normalized = normalize_name(test_name)
129129
print(f"|-- Test '{test_name}' ({normalized})")
130-
130+
131+
# If running in release mode, allow tests to opt-out by creating a
132+
# `no_release.flag` file containing a short justification. When present
133+
# we skip the test and show the justification to the user.
134+
no_release_file = os.path.join(test_dir, "no_release.flag")
135+
if release_mode and os.path.exists(no_release_file):
136+
reason = read_from_file(no_release_file).strip()
137+
# Per request: show the justification when skipping
138+
print(f"Skipping: {reason}")
139+
return True
140+
131141
print("|--- 🧼 Cleaning test folder...")
132142
proc = run_command(["cargo", "clean"], cwd=test_dir)
133143
if proc.returncode != 0:
@@ -159,6 +169,16 @@ def process_integration_test(test_dir: str, release_mode: bool) -> bool:
159169
normalized = normalize_name(test_name)
160170
print(f"|-- Test '{test_name}' ({normalized})")
161171

172+
# If running in release mode, allow tests to opt-out by creating a
173+
# `no_release.flag` file containing a short justification. When present
174+
# we skip the test and show the justification to the user.
175+
no_release_file = os.path.join(test_dir, "no_release.flag")
176+
if release_mode and os.path.exists(no_release_file):
177+
reason = read_from_file(no_release_file).strip()
178+
# Per request: show the justification when skipping
179+
print(f"Skipping: {reason}")
180+
return True
181+
162182
print("|--- 🧼 Cleaning test folder...")
163183
run_command(["cargo", "clean"], cwd=test_dir) # Ignore clean failure for now
164184

proguard/default.pro

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
-keep public class * { public static void main(java.lang.String[]); }
1+
-keep public class *
22
-keep class * {
33
<fields>;
44
}
55
-keepclassmembers class * implements * {
66
<methods>;
77
}
88

9-
-keepattributes MethodParameters
9+
-keepattributes MethodParameters
10+
-keepattributes InnerClasses
11+
-keepattributes EnclosingMethod

src/lower2.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,12 @@ pub fn oomir_to_jvm_bytecode(
172172
methods,
173173
interfaces,
174174
} => {
175+
let mut subclasses = Vec::new();
176+
for (other_dt_name, _) in &module.data_types {
177+
if other_dt_name.starts_with(&format!("{}$", dt_name_oomir)) {
178+
subclasses.push(other_dt_name.clone());
179+
}
180+
}
175181
// Create and serialize the class file for this data type
176182
let dt_bytecode = create_data_type_classfile_for_class(
177183
&dt_name_oomir,
@@ -181,6 +187,7 @@ pub fn oomir_to_jvm_bytecode(
181187
super_class.as_deref().unwrap_or("java/lang/Object"),
182188
interfaces.clone(),
183189
&module,
190+
subclasses,
184191
)?;
185192
generated_classes.insert(dt_name_oomir.clone(), dt_bytecode);
186193
}

src/lower2/jvm_gen.rs

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use crate::oomir::{self, DataTypeMethod, Signature, Type};
66
use ristretto_classfile::{
77
self as jvm, BaseType, ClassAccessFlags, ClassFile, ConstantPool, FieldAccessFlags,
88
MethodAccessFlags, Version,
9-
attributes::{Attribute, Instruction, MaxStack},
9+
attributes::{Attribute, Instruction, MaxStack, InnerClass, NestedClassAccessFlags},
1010
};
1111
use std::collections::HashMap;
1212

@@ -91,6 +91,7 @@ pub(super) fn create_data_type_classfile_for_class(
9191
super_class_name_jvm: &str,
9292
implements_interfaces: Vec<String>,
9393
module: &oomir::Module,
94+
subclasses: Vec<String>,
9495
) -> jvm::Result<Vec<u8>> {
9596
let mut cp = ConstantPool::default();
9697

@@ -244,6 +245,51 @@ pub(super) fn create_data_type_classfile_for_class(
244245
}
245246
}
246247

248+
// --- Add InnerClasses Attribute (for nested/member classes) ---
249+
if !subclasses.is_empty() {
250+
let mut inner_classes_vec: Vec<InnerClass> = Vec::with_capacity(subclasses.len());
251+
252+
for subclass_name in &subclasses {
253+
// Ensure subclass class_info is in the constant pool
254+
let class_info_index = class_file.constant_pool.add_class(subclass_name)?;
255+
256+
// The outer class is this class
257+
let outer_class_info_index = class_file.this_class;
258+
259+
// Derive simple name: part after last '$'. If there's no '$', treat as unnamed (0).
260+
let simple_name_part = subclass_name.rsplit('$').next().unwrap_or(subclass_name);
261+
262+
// If the simple name looks like an anonymous class (all digits), set name_index = 0
263+
let name_index = if simple_name_part.chars().all(|c| c.is_ascii_digit()) {
264+
0
265+
} else if simple_name_part == *subclass_name && !subclass_name.contains('$') {
266+
// No '$' present -> not an inner/member class; leave name_index = 0
267+
0
268+
} else {
269+
class_file
270+
.constant_pool
271+
.add_utf8(simple_name_part)?
272+
};
273+
274+
// Default to PUBLIC | STATIC for generated nested classes. This can be adjusted
275+
// if more precise access info becomes available.
276+
let access_flags = NestedClassAccessFlags::PUBLIC | NestedClassAccessFlags::STATIC;
277+
278+
inner_classes_vec.push(InnerClass {
279+
class_info_index,
280+
outer_class_info_index,
281+
name_index,
282+
access_flags,
283+
});
284+
}
285+
286+
let inner_classes_attr_name_index = class_file.constant_pool.add_utf8("InnerClasses")?;
287+
class_file.attributes.push(Attribute::InnerClasses {
288+
name_index: inner_classes_attr_name_index,
289+
classes: inner_classes_vec,
290+
});
291+
}
292+
247293
// --- Add SourceFile Attribute ---
248294
let simple_name = class_name_jvm.split('/').last().unwrap_or(class_name_jvm);
249295
let source_file_name = format!("{}.rs", simple_name);
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
../../../../config.toml

tests/integration/inner_classes/Cargo.lock

Lines changed: 7 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
cargo-features = ["profile-rustflags"]
2+
3+
[package]
4+
name = "inner_classes"
5+
version = "0.1.0"
6+
edition = "2024"
7+
8+
[dependencies]
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
public class Main {
2+
public static void main(String[] args) {
3+
TrafficLight red = new TrafficLight.Red();
4+
TrafficLight yellow = new TrafficLight.Yellow();
5+
TrafficLight green = new TrafficLight.Green();
6+
7+
String redAction = inner_classes.get_light_action(red);
8+
if (!redAction.equals("Stop")) {
9+
throw new AssertionError("Test failed for Red: expected 'Stop' but got '" + redAction + "'");
10+
}
11+
12+
String yellowAction = inner_classes.get_light_action(yellow);
13+
if (!yellowAction.equals("Caution")) {
14+
throw new AssertionError("Test failed for Yellow: expected 'Caution' but got '" + yellowAction + "'");
15+
}
16+
17+
String greenAction = inner_classes.get_light_action(green);
18+
if (!greenAction.equals("Go")) {
19+
throw new AssertionError("Test failed for Green: expected 'Go' but got '" + greenAction + "'");
20+
}
21+
22+
System.out.println("Inner class access test passed!");
23+
}
24+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
enum variants we want to test will get optimised away.
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
pub enum TrafficLight {
2+
Red,
3+
Yellow,
4+
Green,
5+
}
6+
7+
pub fn get_light_action(light: TrafficLight) -> &'static str {
8+
match light {
9+
TrafficLight::Red => "Stop",
10+
TrafficLight::Yellow => "Caution",
11+
TrafficLight::Green => "Go",
12+
}
13+
}
14+
15+
pub fn main() {
16+
assert!(get_light_action(TrafficLight::Red) == "Stop");
17+
assert!(get_light_action(TrafficLight::Yellow) == "Caution");
18+
assert!(get_light_action(TrafficLight::Green) == "Go");
19+
}

0 commit comments

Comments
 (0)