@@ -476,7 +476,9 @@ def get_port_forward_handles(self, host_ip: str, host_port: int, dest_ip: str, d
476476 if 'masquerade' in e :
477477 has_masquerade = True
478478
479- if has_saddr_match and has_masquerade and comment == f"machine_id={ id } " :
479+ # Note: This function is not currently used, but if it were, it would need an 'id' parameter
480+ # For now, we'll just check for masquerade rules without machine_id matching
481+ if has_saddr_match and has_masquerade :
480482 if self ._config .verbose :
481483 self ._logger .debug (f"Found matching postrouting masquerade rule { rule } " )
482484 self ._logger .info (f"Found postrouting rule with handle { rule ['handle' ]} " )
@@ -490,15 +492,16 @@ def get_port_forward_handles(self, host_ip: str, host_port: int, dest_ip: str, d
490492 except Exception as e :
491493 raise NetworkError (f"Failed to get nftables rules: { str (e )} " )
492494
493- def get_port_forward_by_comment (self , id : str , host_port : int ):
495+ def get_port_forward_by_comment (self , id : str , host_port : int , dest_port : int ):
494496 """Get port forwarding rules by matching the comment pattern.
495497
496498 Args:
497499 id (str): Machine ID to search for
498500 host_port (int): Host port to search for
501+ dest_port (int): Destination port to search for
499502
500503 Returns:
501- dict: Dictionary containing handles for prerouting and postrouting rules.
504+ dict: Dictionary containing handles for prerouting rules only .
502505
503506 Raises:
504507 NetworkError: If retrieving nftables rules fails.
@@ -511,8 +514,8 @@ def get_port_forward_by_comment(self, id: str, host_port: int):
511514 output = self ._nft .json_cmd (list_cmd )
512515 result = output [1 ]['nftables' ]
513516 rules = {}
514- prerouting_comment = f"machine_id= { id } host_port= { host_port } "
515- postrouting_comment = f"machine_id={ id } "
517+
518+ prerouting_comment = f"machine_id={ id } host_port= { host_port } vm_port= { dest_port } "
516519
517520 for item in result :
518521 if 'rule' not in item :
@@ -522,30 +525,61 @@ def get_port_forward_by_comment(self, id: str, host_port: int):
522525 chain = rule .get ('chain' , '' ).upper () # Normalize chain name to uppercase
523526 comment = rule .get ('comment' , '' )
524527
525- # Check for PREROUTING rules with matching comment
528+ # Check for PREROUTING rules with matching comment only
526529 if rule .get ('family' ) == 'ip' and rule .get ('table' ) == 'nat' and chain == 'PREROUTING' :
527530 if comment == prerouting_comment :
528531 if self ._config .verbose :
529532 self ._logger .info (f"Found prerouting rule with matching comment: { comment } " )
530533 self ._logger .debug (f"Rule details: { rule } " )
531534 rules ['prerouting' ] = rule ['handle' ]
532535
533- # Check for POSTROUTING rules with matching comment
534- elif rule .get ('family' ) == 'ip' and rule .get ('table' ) == 'nat' and chain == 'POSTROUTING' :
535- if comment == postrouting_comment :
536- if self ._config .verbose :
537- self ._logger .info (f"Found postrouting rule with matching comment: { comment } " )
538- self ._logger .debug (f"Rule details: { rule } " )
539- rules ['postrouting' ] = rule ['handle' ]
540-
541536 if not rules and self ._config .verbose :
542- self ._logger .info (f"No port forwarding rules found for machine_id={ id } host_port={ host_port } " )
537+ self ._logger .info (f"No port forwarding rules found for machine_id={ id } host_port={ host_port } vm_port= { dest_port } " )
543538
544539 return rules
545540
546541 except Exception as e :
547542 raise NetworkError (f"Failed to get nftables rules: { str (e )} " )
548543
544+ def _check_postrouting_exists (self , id : str ) -> bool :
545+ """Check if a POSTROUTING rule already exists for the given machine ID.
546+
547+ Args:
548+ id (str): Machine ID to check for
549+
550+ Returns:
551+ bool: True if POSTROUTING rule exists, False otherwise
552+ """
553+ try :
554+ list_cmd = {"nftables" : [{"list" : {"table" : {"family" : "ip" , "name" : "nat" }}}]}
555+ output = self ._nft .json_cmd (list_cmd )
556+ result = output [1 ]['nftables' ]
557+
558+ postrouting_comment = f"machine_id={ id } "
559+
560+ for item in result :
561+ if 'rule' not in item :
562+ continue
563+
564+ rule = item ['rule' ]
565+ chain = rule .get ('chain' , '' ).upper ()
566+ comment = rule .get ('comment' , '' )
567+
568+ if (rule .get ('family' ) == 'ip' and
569+ rule .get ('table' ) == 'nat' and
570+ chain == 'POSTROUTING' and
571+ comment == postrouting_comment ):
572+ if self ._config .verbose :
573+ self ._logger .debug (f"Found existing POSTROUTING rule for machine_id={ id } " )
574+ return True
575+
576+ return False
577+
578+ except Exception as e :
579+ if self ._config .verbose :
580+ self ._logger .warn (f"Failed to check for existing POSTROUTING rule: { str (e )} " )
581+ return False
582+
549583 def add_port_forward (self , id : str , host_ip : str , host_port : int , dest_ip : str , dest_port : int , protocol : str = "tcp" ):
550584 """Port forward a port to a new IP and port.
551585
@@ -559,14 +593,15 @@ def add_port_forward(self, id: str, host_ip: str, host_port: int, dest_ip: str,
559593 Raises:
560594 NetworkError: If adding nftables port forwarding rule fails.
561595 """
562- # First check if the rules already exist
563- # existing_rules = self.get_port_forward_handles(host_ip, host_port, dest_ip, dest_port)
564-
565- existing_rules = self .get_port_forward_by_comment (id , host_port )
596+ # First check if the PREROUTING rule already exists
597+ existing_rules = self .get_port_forward_by_comment (id , host_port , dest_port )
566598 if existing_rules :
567599 if self ._config .verbose :
568600 self ._logger .info ("Port forwarding rules already exist" )
569- return
601+ return True
602+
603+ # Check if POSTROUTING rule already exists
604+ postrouting_exists = self ._check_postrouting_exists (id )
570605
571606 # Create the rules
572607 rules = {
@@ -591,96 +626,104 @@ def add_port_forward(self, id: str, host_ip: str, host_port: int, dest_ip: str,
591626 "policy" : "accept"
592627 }
593628 }
594- },
595- {
596- "add" : {
597- "chain" : {
598- "family" : "ip" ,
599- "table" : "nat" ,
600- "name" : "POSTROUTING" ,
601- "type" : "nat" ,
602- "hook" : "postrouting" ,
603- "prio" : 100 ,
604- "policy" : "accept"
605- }
629+ }
630+ ]
631+ }
632+
633+ # Only add POSTROUTING chain if it doesn't exist
634+ if not postrouting_exists :
635+ rules ["nftables" ].append ({
636+ "add" : {
637+ "chain" : {
638+ "family" : "ip" ,
639+ "table" : "nat" ,
640+ "name" : "POSTROUTING" ,
641+ "type" : "nat" ,
642+ "hook" : "postrouting" ,
643+ "prio" : 100 ,
644+ "policy" : "accept"
606645 }
607- },
608- {
609- "add" : {
610- " rule" : {
611- "family" : "ip" ,
612- "table " : "nat" ,
613- "chain " : "PREROUTING" ,
614- "comment " : f"machine_id= { id } host_port= { host_port } " ,
615- "expr " : [
616- {
617- "match " : {
618- "op " : "==" ,
619- "left" : {
620- "payload " : {
621- "protocol " : "ip " ,
622- "field " : "daddr"
623- }
624- } ,
625- "right " : host_ip
646+ }
647+ })
648+
649+ # Add PREROUTING rule
650+ rules [ "nftables" ]. append ({
651+ "add " : {
652+ "rule " : {
653+ "family " : "ip " ,
654+ "table " : "nat" ,
655+ "chain" : "PREROUTING" ,
656+ "comment " : f"machine_id= { id } host_port= { host_port } vm_port= { dest_port } " ,
657+ "expr " : [
658+ {
659+ "match " : {
660+ "op " : "== " ,
661+ "left " : {
662+ "payload" : {
663+ "protocol" : "ip" ,
664+ "field " : "daddr"
626665 }
627666 },
628- {
629- "match" : {
630- "op" : "==" ,
631- "left" : {
632- "payload " : {
633- "protocol " : protocol ,
634- "field " : "dport"
635- }
636- } ,
637- "right " : host_port
667+ "right" : host_ip
668+ }
669+ } ,
670+ {
671+ "match " : {
672+ "op " : "==" ,
673+ "left " : {
674+ "payload" : {
675+ "protocol" : protocol ,
676+ "field " : "dport"
638677 }
639678 },
640- {
641- "dnat" : {
642- "addr" : dest_ip ,
643- "port" : dest_port
644- }
645- }
646- ]
679+ "right" : host_port
680+ }
681+ },
682+ {
683+ "dnat" : {
684+ "addr" : dest_ip ,
685+ "port" : dest_port
686+ }
647687 }
648- }
649- },
650- {
651- "add" : {
652- "rule" : {
653- "family" : "ip" ,
654- "table" : "nat" ,
655- "chain" : "POSTROUTING" ,
656- "comment" : f"machine_id={ id } " ,
657- "expr" : [
658- {
659- "match" : {
660- "op" : "==" ,
661- "left" : {
662- "payload" : {
663- "protocol" : "ip" ,
664- "field" : "saddr"
665- }
666- },
667- "right" : {
668- "prefix" : {
669- "addr" : dest_ip ,
670- "len" : 32
671- }
688+ ]
689+ }
690+ }
691+ })
692+
693+ # Only add POSTROUTING rule if it doesn't already exist
694+ if not postrouting_exists :
695+ rules ["nftables" ].append ({
696+ "add" : {
697+ "rule" : {
698+ "family" : "ip" ,
699+ "table" : "nat" ,
700+ "chain" : "POSTROUTING" ,
701+ "comment" : f"machine_id={ id } " ,
702+ "expr" : [
703+ {
704+ "match" : {
705+ "op" : "==" ,
706+ "left" : {
707+ "payload" : {
708+ "protocol" : "ip" ,
709+ "field" : "saddr"
710+ }
711+ },
712+ "right" : {
713+ "prefix" : {
714+ "addr" : dest_ip ,
715+ "len" : 32
672716 }
673717 }
674- },
675- {
676- "masquerade" : None
677718 }
678- ]
679- }
719+ },
720+ {
721+ "masquerade" : None
722+ }
723+ ]
680724 }
681725 }
682- ]
683- }
726+ })
684727
685728 try :
686729 for rule in rules ["nftables" ]:
@@ -764,12 +807,13 @@ def delete_masquerade(self):
764807 except Exception as e :
765808 raise NetworkError (f"Failed to delete masquerade rule: { str (e )} " )
766809
767- def delete_port_forward (self , id : str , host_port : int ):
810+ def delete_port_forward (self , id : str , host_port : int , dest_port : int ):
768811 """Delete port forwarding rules.
769812
770813 Args:
771- host_port (int): Port being forwarded.
772- machine_id (str): ID of the machine for which port forwarding is being deleted.
814+ id (str): Machine ID for which port forwarding is being deleted.
815+ host_port (int): Host port being forwarded.
816+ dest_port (int): Destination port being forwarded to.
773817
774818 Raises:
775819 NetworkError: If deleting port forwarding rules fails.
@@ -791,7 +835,9 @@ def delete_port_forward(self, id: str, host_port: int):
791835 rule = item ['rule' ]
792836 comment = rule .get ('comment' , '' )
793837
794- if f"machine_id={ id } host_port={ host_port } " in comment :
838+ comment_matches = f"machine_id={ id } host_port={ host_port } vm_port={ dest_port } " in comment
839+
840+ if comment_matches :
795841 chain = rule .get ('chain' , '' ).upper ()
796842 handle = rule ['handle' ]
797843
0 commit comments