Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions riichienv-core/src/observation/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,7 @@ mod tests {
[None; 4], // riichi_sutehais
[None; 4], // last_tedashis
None, // last_discard
None, // drawn_tile
)
}

Expand Down
194 changes: 194 additions & 0 deletions riichienv-core/src/observation/mjai_select.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
//! Shared helpers for mapping Mjai messages to a legal `Action`.
//!
//! Used by both `Observation::select_action_from_mjai` (4P) and
//! `Observation3P::select_action_from_mjai` (3P).

use pyo3::prelude::*;
use pyo3::types::{PyDict, PyDictMethods};

use crate::action::{Action, ActionType};
use crate::parser::tid_to_mjai;

pub(crate) struct ParsedMjai {
pub type_str: String,
pub tile_str: String,
pub tsumogiri: Option<bool>,
pub consumed: Option<Vec<String>>,
}

pub(crate) fn parse_mjai_message(mjai_data: &Bound<'_, PyAny>) -> Option<ParsedMjai> {
if let Ok(s) = mjai_data.extract::<String>() {
let v: serde_json::Value = serde_json::from_str(&s).ok()?;
let type_str = v["type"].as_str()?.to_string();
let tile_str = v["pai"].as_str().unwrap_or("").to_string();
let tsumogiri = v.get("tsumogiri").and_then(|x| x.as_bool());
let consumed = v.get("consumed").and_then(|x| x.as_array()).map(|arr| {
arr.iter()
.filter_map(|e| e.as_str().map(|s| s.to_string()))
.collect::<Vec<_>>()
});
Some(ParsedMjai {
type_str,
tile_str,
tsumogiri,
consumed,
})
} else if let Ok(dict) = mjai_data.cast::<PyDict>() {
let type_str: String = dict
.get_item("type")
.ok()
.flatten()
.and_then(|x| x.extract::<String>().ok())
.unwrap_or_default();
let tile_str: String = dict
.get_item("pai")
.ok()
.flatten()
.or_else(|| dict.get_item("tile").ok().flatten())
.and_then(|x| x.extract::<String>().ok())
.unwrap_or_default();
let tsumogiri = dict
.get_item("tsumogiri")
.ok()
.flatten()
.and_then(|x| x.extract::<bool>().ok());
let consumed = dict
.get_item("consumed")
.ok()
.flatten()
.and_then(|x| x.extract::<Vec<String>>().ok());
Some(ParsedMjai {
type_str,
tile_str,
tsumogiri,
consumed,
})
} else {
None
}
}

fn consumed_matches(action_consume: &[u8], expected: &[String]) -> bool {
if action_consume.len() != expected.len() {
return false;
}
let mut a: Vec<String> = action_consume.iter().map(|&t| tid_to_mjai(t)).collect();
let mut b: Vec<String> = expected.to_vec();
a.sort();
b.sort();
a == b
}

/// Select a matching `Action` from a slice of legal actions for a parsed Mjai
/// message.
///
/// `three_player` controls whether 3P-only types (`kita`) are recognized; chi
/// is rejected when set.
pub(crate) fn select_action<'a>(
legal_actions: &'a [Action],
parsed: &ParsedMjai,
drawn_tile: Option<u8>,
three_player: bool,
) -> Option<&'a Action> {
let atype = parsed.type_str.as_str();

if atype == "hora" {
return legal_actions
.iter()
.find(|a| matches!(a.action_type, ActionType::Tsumo | ActionType::Ron));
}

if atype == "none" {
return legal_actions
.iter()
.find(|a| a.action_type == ActionType::Pass);
}

let target_type = match atype {
"dahai" => Some(ActionType::Discard),
"chi" if !three_player => Some(ActionType::Chi),
"pon" => Some(ActionType::Pon),
"kakan" => Some(ActionType::Kakan),
"daiminkan" => Some(ActionType::Daiminkan),
"ankan" => Some(ActionType::Ankan),
"kita" if three_player => Some(ActionType::Kita),
"reach" => Some(ActionType::Riichi),
"ryukyoku" => Some(ActionType::KyushuKyuhai),
_ => None,
};

let tt = target_type?;

// Special-case Discard: filter by mjai pai (or any Discard if pai is
// omitted) then disambiguate via tsumogiri.
//
// NOTE: An mjai `dahai` message without a `pai` field is malformed per
// the protocol, but we still return a non-empty Action (the first
// legal Discard) instead of `None` to preserve backward compatibility
// with the previous implementation; bailing out here would silently
// break callers that rely on the old lenient behavior.
if tt == ActionType::Discard {
let candidates: Vec<&Action> = legal_actions
.iter()
.filter(|a| {
a.action_type == ActionType::Discard
&& (parsed.tile_str.is_empty()
|| a.tile.is_some_and(|t| tid_to_mjai(t) == parsed.tile_str))
})
.collect();

if candidates.is_empty() {
return None;
}

if let (Some(tsumogiri), Some(drawn)) = (parsed.tsumogiri, drawn_tile) {
let preferred = candidates.iter().find(|a| {
let is_drawn = a.tile == Some(drawn);
if tsumogiri { is_drawn } else { !is_drawn }
});
if let Some(a) = preferred {
return Some(*a);
}
}

return Some(candidates[0]);
}

legal_actions.iter().find(|a| {
if a.action_type != tt {
return false;
}

if let Some(consumed) = parsed.consumed.as_ref() {
if !consumed_matches(&a.consume_tiles, consumed) {
return false;
}
// If pai is also given, double-check tile match for actions that
// carry a meaningful tile (chi/pon/daiminkan/kakan).
if !parsed.tile_str.is_empty()
&& matches!(
tt,
ActionType::Chi | ActionType::Pon | ActionType::Daiminkan | ActionType::Kakan
)
{
if let Some(t) = a.tile {
if tid_to_mjai(t) != parsed.tile_str {
return false;
}
} else {
return false;
}
}
return true;
}

// No consumed field: fall back to pai-based match.
if !parsed.tile_str.is_empty() {
if let Some(t) = a.tile {
return tid_to_mjai(t) == parsed.tile_str;
}
return false;
}
true
})
}
6 changes: 6 additions & 0 deletions riichienv-core/src/observation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ mod encode;
#[cfg(feature = "python")]
pub(crate) mod helpers;
#[cfg(feature = "python")]
pub(crate) mod mjai_select;
#[cfg(feature = "python")]
mod python;
#[cfg(feature = "python")]
pub(crate) mod sequence_features;
Expand Down Expand Up @@ -49,6 +51,8 @@ pub struct Observation {
pub riichi_sutehais: [Option<u8>; 4],
pub last_tedashis: [Option<u8>; 4],
pub last_discard: Option<u32>,
#[serde(default)]
pub drawn_tile: Option<u8>,
}

/// Pure Rust methods (no PyO3 dependency).
Expand All @@ -74,6 +78,7 @@ impl Observation {
riichi_sutehais: [Option<u8>; 4],
last_tedashis: [Option<u8>; 4],
last_discard: Option<u32>,
drawn_tile: Option<u8>,
) -> Self {
let hands_u32 = hands.map(|h| h.into_iter().map(|x| x as u32).collect());
let discards_u32 = discards.map(|d| d.into_iter().map(|x| x as u32).collect());
Expand Down Expand Up @@ -101,6 +106,7 @@ impl Observation {
riichi_sutehais,
last_tedashis,
last_discard,
drawn_tile,
}
}

Expand Down
101 changes: 6 additions & 95 deletions riichienv-core/src/observation/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use super::helpers::get_next_tile;
impl Observation {
#[new]
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (player_id, hands, melds, discards, dora_indicators, scores, riichi_declared, legal_actions, events, honba, riichi_sticks, round_wind, oya, kyoku_index, waits, is_tenpai, riichi_sutehais, last_tedashis, last_discard))]
#[pyo3(signature = (player_id, hands, melds, discards, dora_indicators, scores, riichi_declared, legal_actions, events, honba, riichi_sticks, round_wind, oya, kyoku_index, waits, is_tenpai, riichi_sutehais, last_tedashis, last_discard, drawn_tile=None))]
pub fn py_new(
player_id: u8,
hands: Vec<Vec<u8>>,
Expand All @@ -35,6 +35,7 @@ impl Observation {
riichi_sutehais: Vec<Option<u8>>,
last_tedashis: Vec<Option<u8>>,
last_discard: Option<u32>,
drawn_tile: Option<u8>,
) -> Self {
let hands: [Vec<u8>; 4] = hands.try_into().expect("expected 4 hands");
let melds: [Vec<Meld>; 4] = melds.try_into().expect("expected 4 melds");
Expand Down Expand Up @@ -68,6 +69,7 @@ impl Observation {
riichi_sutehais,
last_tedashis,
last_discard,
drawn_tile,
)
}

Expand Down Expand Up @@ -120,100 +122,9 @@ impl Observation {

#[pyo3(signature = (mjai_data))]
pub fn select_action_from_mjai(&self, mjai_data: &Bound<'_, PyAny>) -> Option<Action> {
let (atype, tile_str) = if let Ok(s) = mjai_data.extract::<String>() {
let v: serde_json::Value = serde_json::from_str(&s).ok()?;
(
v["type"].as_str()?.to_string(),
v["pai"].as_str().unwrap_or("").to_string(),
)
} else if let Ok(dict) = mjai_data.cast::<PyDict>() {
let type_str: String = dict
.get_item("type")
.ok()
.flatten()
.and_then(|x| x.extract::<String>().ok())
.unwrap_or_default();
let _args_list: Vec<String> = dict
.get_item("args")
.ok()
.flatten()
.and_then(|x| x.extract::<Vec<String>>().ok())
.unwrap_or_default();
let _who: i8 = dict
.get_item("who")
.ok()
.flatten()
.and_then(|x| x.extract::<i8>().ok())
.unwrap_or(-1);
let tile_str: String = dict
.get_item("pai")
.ok()
.flatten()
.or_else(|| dict.get_item("tile").ok().flatten())
.and_then(|x| x.extract::<String>().ok())
.unwrap_or_default();
(type_str, tile_str)
} else {
return None;
};

let target_type = match atype.as_str() {
"dahai" => Some(crate::action::ActionType::Discard),
"chi" => Some(crate::action::ActionType::Chi),
"pon" => Some(crate::action::ActionType::Pon),
"kakan" => Some(crate::action::ActionType::Kakan),
"daiminkan" => Some(crate::action::ActionType::Daiminkan),
"ankan" => Some(crate::action::ActionType::Ankan),
"reach" => Some(crate::action::ActionType::Riichi),
"hora" => None,
"ryukyoku" => Some(crate::action::ActionType::KyushuKyuhai),
_ => None,
};

if atype == "hora" {
return self
._legal_actions
.iter()
.find(|a| {
a.action_type == crate::action::ActionType::Tsumo
|| a.action_type == crate::action::ActionType::Ron
})
.cloned();
}

if let Some(tt) = target_type {
return self
._legal_actions
.iter()
.find(|a| {
if a.action_type != tt {
return false;
}
if !tile_str.is_empty() {
if let Some(t) = a.tile {
let t_str = crate::parser::tid_to_mjai(t);
if t_str == tile_str {
return true;
}
return false;
} else {
return false;
}
}
true
})
.cloned();
}

if atype == "none" {
return self
._legal_actions
.iter()
.find(|a| a.action_type == crate::action::ActionType::Pass)
.cloned();
}

None
use super::mjai_select::{parse_mjai_message, select_action};
let parsed = parse_mjai_message(mjai_data)?;
select_action(&self._legal_actions, &parsed, self.drawn_tile, false).cloned()
}

#[pyo3(name = "new_events")]
Expand Down
1 change: 1 addition & 0 deletions riichienv-core/src/observation_3p/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,7 @@ mod tests {
[None; 3], // riichi_sutehais
[None; 3], // last_tedashis
None, // last_discard
None, // drawn_tile
)
}

Expand Down
4 changes: 4 additions & 0 deletions riichienv-core/src/observation_3p/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ pub struct Observation3P {
pub riichi_sutehais: [Option<u8>; 3],
pub last_tedashis: [Option<u8>; 3],
pub last_discard: Option<u32>,
#[serde(default)]
pub drawn_tile: Option<u8>,
}

/// Pure Rust methods (no PyO3 dependency).
Expand All @@ -66,6 +68,7 @@ impl Observation3P {
riichi_sutehais: [Option<u8>; 3],
last_tedashis: [Option<u8>; 3],
last_discard: Option<u32>,
drawn_tile: Option<u8>,
) -> Self {
let hands_u32 = hands.map(|h| h.into_iter().map(|x| x as u32).collect());
let discards_u32 = discards.map(|d| d.into_iter().map(|x| x as u32).collect());
Expand Down Expand Up @@ -95,6 +98,7 @@ impl Observation3P {
riichi_sutehais,
last_tedashis,
last_discard,
drawn_tile,
}
}

Expand Down
Loading
Loading