Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 58 additions & 44 deletions UnityMcpBridge/UnityMcpServer~/src/unity_connection.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import socket
import contextlib
import errno
import json
import logging
import random
import socket
import struct
import threading
import time
from dataclasses import dataclass
from pathlib import Path
import time
import random
import errno
from typing import Dict, Any
from typing import Any, Dict
from config import config
from port_discovery import PortDiscovery

Expand All @@ -30,6 +32,7 @@ def __post_init__(self):
"""Set port from discovery if not explicitly provided"""
if self.port is None:
self.port = PortDiscovery.discover_unity_port()
self._io_lock = threading.Lock()

def connect(self) -> bool:
"""Establish a connection to the Unity Editor."""
Expand All @@ -42,20 +45,24 @@ def connect(self) -> bool:

# Strict handshake: require FRAMING=1
try:
self.sock.settimeout(1.0)
require_framing = getattr(config, "require_framing", True)
self.sock.settimeout(getattr(config, "handshake_timeout", 1.0))
greeting = self.sock.recv(256)
text = greeting.decode('ascii', errors='ignore') if greeting else ''
if 'FRAMING=1' in text:
self.use_framing = True
logger.debug('Unity MCP handshake received: FRAMING=1 (strict)')
else:
try:
msg = b'Unity MCP requires FRAMING=1'
header = struct.pack('>Q', len(msg))
self.sock.sendall(header + msg)
except Exception:
pass
raise ConnectionError(f'Unity MCP requires FRAMING=1, got: {text!r}')
if require_framing:
# Best-effort advisory; peer may ignore if not framed-capable
with contextlib.suppress(Exception):
msg = b'Unity MCP requires FRAMING=1'
header = struct.pack('>Q', len(msg))
self.sock.sendall(header + msg)
Comment on lines +58 to +61
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: The framed message sent here may not be properly handled by a non-framing peer, potentially causing connection issues. Consider logging when this advisory message is sent.

raise ConnectionError(f'Unity MCP requires FRAMING=1, got: {text!r}')
else:
self.use_framing = False
logger.warning('Unity MCP handshake missing FRAMING=1; proceeding in legacy mode by configuration')
finally:
self.sock.settimeout(config.connection_timeout)
return True
Expand Down Expand Up @@ -101,9 +108,9 @@ def receive_full_response(self, sock, buffer_size=config.buffer_size) -> bytes:
payload = self._read_exact(sock, payload_len)
logger.info(f"Received framed response ({len(payload)} bytes)")
return payload
except socket.timeout:
except socket.timeout as e:
logger.warning("Socket timeout during framed receive")
raise Exception("Timeout receiving Unity response")
raise TimeoutError("Timeout receiving Unity response") from e
except Exception as e:
logger.error(f"Error during framed receive: {str(e)}")
raise
Expand Down Expand Up @@ -201,10 +208,9 @@ def read_status_file() -> dict | None:

for attempt in range(attempts + 1):
try:
# Ensure connected (perform handshake each time so framing stays correct)
if not self.sock:
if not self.connect():
raise Exception("Could not connect to Unity")
# Ensure connected (handshake occurs within connect())
if not self.sock and not self.connect():
raise Exception("Could not connect to Unity")

# Build payload
if command_type == 'ping':
Expand All @@ -213,31 +219,39 @@ def read_status_file() -> dict | None:
command = {"type": command_type, "params": params or {}}
payload = json.dumps(command, ensure_ascii=False).encode('utf-8')

# Send
try:
logger.debug(f"send {len(payload)} bytes; mode={'framed' if self.use_framing else 'legacy'}; head={(payload[:32]).decode('utf-8','ignore')}")
except Exception:
pass
if self.use_framing:
header = struct.pack('>Q', len(payload))
self.sock.sendall(header)
self.sock.sendall(payload)
else:
self.sock.sendall(payload)

# During retry bursts use a short receive timeout
if attempt > 0 and last_short_timeout is None:
last_short_timeout = self.sock.gettimeout()
self.sock.settimeout(1.0)
response_data = self.receive_full_response(self.sock)
try:
logger.debug(f"recv {len(response_data)} bytes; mode={'framed' if self.use_framing else 'legacy'}; head={(response_data[:32]).decode('utf-8','ignore')}")
except Exception:
pass
# restore steady-state timeout if changed
if last_short_timeout is not None:
self.sock.settimeout(config.connection_timeout)
last_short_timeout = None
# Send/receive are serialized to protect the shared socket
with self._io_lock:
mode = 'framed' if self.use_framing else 'legacy'
with contextlib.suppress(Exception):
logger.debug(
"send %d bytes; mode=%s; head=%s",
len(payload),
mode,
(payload[:32]).decode('utf-8', 'ignore'),
)
if self.use_framing:
header = struct.pack('>Q', len(payload))
self.sock.sendall(header)
self.sock.sendall(payload)
else:
self.sock.sendall(payload)

# During retry bursts use a short receive timeout
if attempt > 0 and last_short_timeout is None:
last_short_timeout = self.sock.gettimeout()
self.sock.settimeout(1.0)
response_data = self.receive_full_response(self.sock)
with contextlib.suppress(Exception):
logger.debug(
"recv %d bytes; mode=%s; head=%s",
len(response_data),
mode,
(response_data[:32]).decode('utf-8', 'ignore'),
)
# restore steady-state timeout if changed
if last_short_timeout is not None:
self.sock.settimeout(last_short_timeout)
last_short_timeout = None

# Parse
if command_type == 'ping':
Expand Down
22 changes: 17 additions & 5 deletions tests/test_logging_stdout.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,24 @@ def visit_Call(self, node: ast.Call):
# print(...)
if isinstance(node.func, ast.Name) and node.func.id == "print":
self.hit = True
# builtins.print(...)
elif (
isinstance(node.func, ast.Attribute)
and node.func.attr == "print"
and isinstance(node.func.value, ast.Name)
and node.func.value.id == "builtins"
):
self.hit = True
# sys.stdout.write(...)
if isinstance(node.func, ast.Attribute) and node.func.attr == "write":
val = node.func.value
if isinstance(val, ast.Attribute) and val.attr == "stdout":
if isinstance(val.value, ast.Name) and val.value.id == "sys":
self.hit = True
if (
isinstance(node.func, ast.Attribute)
and node.func.attr == "write"
and isinstance(node.func.value, ast.Attribute)
and node.func.value.attr == "stdout"
and isinstance(node.func.value.value, ast.Name)
and node.func.value.value.id == "sys"
):
self.hit = True
self.generic_visit(node)

v = StdoutVisitor()
Expand Down
1 change: 1 addition & 0 deletions tests/test_transport_framing.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def test_unframed_data_disconnect():
port = start_handshake_enforcing_server()
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(("127.0.0.1", port))
sock.settimeout(1.0)
sock.sendall(b"BAD")
time.sleep(0.4)
try:
Expand Down