Skip to content

Commit 11b29cc

Browse files
authored
Merge pull request myugan#21 from myugan/fix/port-forwarding-logic
fix: update port forwarding logic
2 parents 41237f8 + 96335f5 commit 11b29cc

2 files changed

Lines changed: 150 additions & 104 deletions

File tree

firecracker/microvm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1248,7 +1248,7 @@ def _remove_port_forwarding(self, host_ports, dest_ports, vmm_id=None, update_co
12481248
dest_ports_list = [dest_ports] if isinstance(dest_ports, int) else dest_ports
12491249

12501250
for host_port, dest_port in zip(host_ports_list, dest_ports_list):
1251-
self._network.delete_port_forward(vmm_id, host_port)
1251+
self._network.delete_port_forward(vmm_id, host_port, dest_port)
12521252
if self._config.verbose:
12531253
self._logger.debug(f"Removed {host_port} -> {dest_port} from VMM {vmm_id}")
12541254

firecracker/network.py

Lines changed: 149 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,9 @@ def get_port_forward_handles(self, host_ip: str, host_port: int, dest_ip: str, d
476476
if 'masquerade' in e:
477477
has_masquerade = True
478478

479-
if has_saddr_match and has_masquerade and comment == f"machine_id={id}":
479+
# Note: This function is not currently used, but if it were, it would need an 'id' parameter
480+
# For now, we'll just check for masquerade rules without machine_id matching
481+
if has_saddr_match and has_masquerade:
480482
if self._config.verbose:
481483
self._logger.debug(f"Found matching postrouting masquerade rule {rule}")
482484
self._logger.info(f"Found postrouting rule with handle {rule['handle']}")
@@ -490,15 +492,16 @@ def get_port_forward_handles(self, host_ip: str, host_port: int, dest_ip: str, d
490492
except Exception as e:
491493
raise NetworkError(f"Failed to get nftables rules: {str(e)}")
492494

493-
def get_port_forward_by_comment(self, id: str, host_port: int):
495+
def get_port_forward_by_comment(self, id: str, host_port: int, dest_port: int):
494496
"""Get port forwarding rules by matching the comment pattern.
495497
496498
Args:
497499
id (str): Machine ID to search for
498500
host_port (int): Host port to search for
501+
dest_port (int): Destination port to search for
499502
500503
Returns:
501-
dict: Dictionary containing handles for prerouting and postrouting rules.
504+
dict: Dictionary containing handles for prerouting rules only.
502505
503506
Raises:
504507
NetworkError: If retrieving nftables rules fails.
@@ -511,8 +514,8 @@ def get_port_forward_by_comment(self, id: str, host_port: int):
511514
output = self._nft.json_cmd(list_cmd)
512515
result = output[1]['nftables']
513516
rules = {}
514-
prerouting_comment = f"machine_id={id} host_port={host_port}"
515-
postrouting_comment = f"machine_id={id}"
517+
518+
prerouting_comment = f"machine_id={id} host_port={host_port} vm_port={dest_port}"
516519

517520
for item in result:
518521
if 'rule' not in item:
@@ -522,30 +525,61 @@ def get_port_forward_by_comment(self, id: str, host_port: int):
522525
chain = rule.get('chain', '').upper() # Normalize chain name to uppercase
523526
comment = rule.get('comment', '')
524527

525-
# Check for PREROUTING rules with matching comment
528+
# Check for PREROUTING rules with matching comment only
526529
if rule.get('family') == 'ip' and rule.get('table') == 'nat' and chain == 'PREROUTING':
527530
if comment == prerouting_comment:
528531
if self._config.verbose:
529532
self._logger.info(f"Found prerouting rule with matching comment: {comment}")
530533
self._logger.debug(f"Rule details: {rule}")
531534
rules['prerouting'] = rule['handle']
532535

533-
# Check for POSTROUTING rules with matching comment
534-
elif rule.get('family') == 'ip' and rule.get('table') == 'nat' and chain == 'POSTROUTING':
535-
if comment == postrouting_comment:
536-
if self._config.verbose:
537-
self._logger.info(f"Found postrouting rule with matching comment: {comment}")
538-
self._logger.debug(f"Rule details: {rule}")
539-
rules['postrouting'] = rule['handle']
540-
541536
if not rules and self._config.verbose:
542-
self._logger.info(f"No port forwarding rules found for machine_id={id} host_port={host_port}")
537+
self._logger.info(f"No port forwarding rules found for machine_id={id} host_port={host_port} vm_port={dest_port}")
543538

544539
return rules
545540

546541
except Exception as e:
547542
raise NetworkError(f"Failed to get nftables rules: {str(e)}")
548543

544+
def _check_postrouting_exists(self, id: str) -> bool:
545+
"""Check if a POSTROUTING rule already exists for the given machine ID.
546+
547+
Args:
548+
id (str): Machine ID to check for
549+
550+
Returns:
551+
bool: True if POSTROUTING rule exists, False otherwise
552+
"""
553+
try:
554+
list_cmd = {"nftables": [{"list": {"table": {"family": "ip", "name": "nat"}}}]}
555+
output = self._nft.json_cmd(list_cmd)
556+
result = output[1]['nftables']
557+
558+
postrouting_comment = f"machine_id={id}"
559+
560+
for item in result:
561+
if 'rule' not in item:
562+
continue
563+
564+
rule = item['rule']
565+
chain = rule.get('chain', '').upper()
566+
comment = rule.get('comment', '')
567+
568+
if (rule.get('family') == 'ip' and
569+
rule.get('table') == 'nat' and
570+
chain == 'POSTROUTING' and
571+
comment == postrouting_comment):
572+
if self._config.verbose:
573+
self._logger.debug(f"Found existing POSTROUTING rule for machine_id={id}")
574+
return True
575+
576+
return False
577+
578+
except Exception as e:
579+
if self._config.verbose:
580+
self._logger.warn(f"Failed to check for existing POSTROUTING rule: {str(e)}")
581+
return False
582+
549583
def add_port_forward(self, id: str, host_ip: str, host_port: int, dest_ip: str, dest_port: int, protocol: str = "tcp"):
550584
"""Port forward a port to a new IP and port.
551585
@@ -559,14 +593,15 @@ def add_port_forward(self, id: str, host_ip: str, host_port: int, dest_ip: str,
559593
Raises:
560594
NetworkError: If adding nftables port forwarding rule fails.
561595
"""
562-
# First check if the rules already exist
563-
# existing_rules = self.get_port_forward_handles(host_ip, host_port, dest_ip, dest_port)
564-
565-
existing_rules = self.get_port_forward_by_comment(id, host_port)
596+
# First check if the PREROUTING rule already exists
597+
existing_rules = self.get_port_forward_by_comment(id, host_port, dest_port)
566598
if existing_rules:
567599
if self._config.verbose:
568600
self._logger.info("Port forwarding rules already exist")
569-
return
601+
return True
602+
603+
# Check if POSTROUTING rule already exists
604+
postrouting_exists = self._check_postrouting_exists(id)
570605

571606
# Create the rules
572607
rules = {
@@ -591,96 +626,104 @@ def add_port_forward(self, id: str, host_ip: str, host_port: int, dest_ip: str,
591626
"policy": "accept"
592627
}
593628
}
594-
},
595-
{
596-
"add": {
597-
"chain": {
598-
"family": "ip",
599-
"table": "nat",
600-
"name": "POSTROUTING",
601-
"type": "nat",
602-
"hook": "postrouting",
603-
"prio": 100,
604-
"policy": "accept"
605-
}
629+
}
630+
]
631+
}
632+
633+
# Only add POSTROUTING chain if it doesn't exist
634+
if not postrouting_exists:
635+
rules["nftables"].append({
636+
"add": {
637+
"chain": {
638+
"family": "ip",
639+
"table": "nat",
640+
"name": "POSTROUTING",
641+
"type": "nat",
642+
"hook": "postrouting",
643+
"prio": 100,
644+
"policy": "accept"
606645
}
607-
},
608-
{
609-
"add": {
610-
"rule": {
611-
"family": "ip",
612-
"table": "nat",
613-
"chain": "PREROUTING",
614-
"comment": f"machine_id={id} host_port={host_port}",
615-
"expr": [
616-
{
617-
"match": {
618-
"op": "==",
619-
"left": {
620-
"payload": {
621-
"protocol": "ip",
622-
"field": "daddr"
623-
}
624-
},
625-
"right": host_ip
646+
}
647+
})
648+
649+
# Add PREROUTING rule
650+
rules["nftables"].append({
651+
"add": {
652+
"rule": {
653+
"family": "ip",
654+
"table": "nat",
655+
"chain": "PREROUTING",
656+
"comment": f"machine_id={id} host_port={host_port} vm_port={dest_port}",
657+
"expr": [
658+
{
659+
"match": {
660+
"op": "==",
661+
"left": {
662+
"payload": {
663+
"protocol": "ip",
664+
"field": "daddr"
626665
}
627666
},
628-
{
629-
"match": {
630-
"op": "==",
631-
"left": {
632-
"payload": {
633-
"protocol": protocol,
634-
"field": "dport"
635-
}
636-
},
637-
"right": host_port
667+
"right": host_ip
668+
}
669+
},
670+
{
671+
"match": {
672+
"op": "==",
673+
"left": {
674+
"payload": {
675+
"protocol": protocol,
676+
"field": "dport"
638677
}
639678
},
640-
{
641-
"dnat": {
642-
"addr": dest_ip,
643-
"port": dest_port
644-
}
645-
}
646-
]
679+
"right": host_port
680+
}
681+
},
682+
{
683+
"dnat": {
684+
"addr": dest_ip,
685+
"port": dest_port
686+
}
647687
}
648-
}
649-
},
650-
{
651-
"add": {
652-
"rule": {
653-
"family": "ip",
654-
"table": "nat",
655-
"chain": "POSTROUTING",
656-
"comment": f"machine_id={id}",
657-
"expr": [
658-
{
659-
"match": {
660-
"op": "==",
661-
"left": {
662-
"payload": {
663-
"protocol": "ip",
664-
"field": "saddr"
665-
}
666-
},
667-
"right": {
668-
"prefix": {
669-
"addr": dest_ip,
670-
"len": 32
671-
}
688+
]
689+
}
690+
}
691+
})
692+
693+
# Only add POSTROUTING rule if it doesn't already exist
694+
if not postrouting_exists:
695+
rules["nftables"].append({
696+
"add": {
697+
"rule": {
698+
"family": "ip",
699+
"table": "nat",
700+
"chain": "POSTROUTING",
701+
"comment": f"machine_id={id}",
702+
"expr": [
703+
{
704+
"match": {
705+
"op": "==",
706+
"left": {
707+
"payload": {
708+
"protocol": "ip",
709+
"field": "saddr"
710+
}
711+
},
712+
"right": {
713+
"prefix": {
714+
"addr": dest_ip,
715+
"len": 32
672716
}
673717
}
674-
},
675-
{
676-
"masquerade": None
677718
}
678-
]
679-
}
719+
},
720+
{
721+
"masquerade": None
722+
}
723+
]
680724
}
681725
}
682-
]
683-
}
726+
})
684727

685728
try:
686729
for rule in rules["nftables"]:
@@ -764,12 +807,13 @@ def delete_masquerade(self):
764807
except Exception as e:
765808
raise NetworkError(f"Failed to delete masquerade rule: {str(e)}")
766809

767-
def delete_port_forward(self, id: str, host_port: int):
810+
def delete_port_forward(self, id: str, host_port: int, dest_port: int):
768811
"""Delete port forwarding rules.
769812
770813
Args:
771-
host_port (int): Port being forwarded.
772-
machine_id (str): ID of the machine for which port forwarding is being deleted.
814+
id (str): Machine ID for which port forwarding is being deleted.
815+
host_port (int): Host port being forwarded.
816+
dest_port (int): Destination port being forwarded to.
773817
774818
Raises:
775819
NetworkError: If deleting port forwarding rules fails.
@@ -791,7 +835,9 @@ def delete_port_forward(self, id: str, host_port: int):
791835
rule = item['rule']
792836
comment = rule.get('comment', '')
793837

794-
if f"machine_id={id} host_port={host_port}" in comment:
838+
comment_matches = f"machine_id={id} host_port={host_port} vm_port={dest_port}" in comment
839+
840+
if comment_matches:
795841
chain = rule.get('chain', '').upper()
796842
handle = rule['handle']
797843

0 commit comments

Comments
 (0)