Skip to content

Commit d342902

Browse files
committed
🔧 Adding
Visitor Design Pattern
1 parent e6eaaf9 commit d342902

File tree

1 file changed

+112
-0
lines changed

1 file changed

+112
-0
lines changed
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
from abc import ABC, abstractmethod
2+
3+
4+
# ------------- ELEMENT INTERFACE -------------
5+
class GeoNode(ABC):
6+
@abstractmethod
7+
def accept(self, visitor: 'GeoVisitor') -> None:
8+
...
9+
10+
11+
# ------------- CONCRETE ELEMENTS -------------
12+
class City(GeoNode):
13+
def __init__(self, name: str, population: int):
14+
self.name = name
15+
self.population = population
16+
17+
def accept(self, visitor: 'GeoVisitor') -> None:
18+
visitor.visit_city(self)
19+
20+
21+
class Industry(GeoNode):
22+
def __init__(self, sector: str, annual_output: float):
23+
self.sector = sector
24+
self.annual_output = annual_output
25+
26+
def accept(self, visitor: 'GeoVisitor') -> None:
27+
visitor.visit_industry(self)
28+
29+
30+
class SightSeeing(GeoNode):
31+
def __init__(self, name: str, rating: float):
32+
self.name = name
33+
self.rating = rating
34+
35+
def accept(self, visitor: 'GeoVisitor') -> None:
36+
visitor.visit_sightseeing(self)
37+
38+
39+
# ------------- VISITOR INTERFACE -------------
40+
class GeoVisitor(ABC):
41+
@abstractmethod
42+
def visit_city(self, city: City) -> None:
43+
...
44+
45+
@abstractmethod
46+
def visit_industry(self, industry: Industry) -> None:
47+
...
48+
49+
@abstractmethod
50+
def visit_sightseeing(self, sight: SightSeeing) -> None:
51+
...
52+
53+
54+
# ------------- CONCRETE VISITOR: XML Export -------------
55+
class XMLExportVisitor(GeoVisitor):
56+
def visit_city(self, city: City) -> None:
57+
print(f"<city><name>{city.name}</name><population>{city.population}</population></city>")
58+
59+
def visit_industry(self, industry: Industry) -> None:
60+
print(f"<industry><sector>{industry.sector}</sector><output>{industry.annual_output}</output></industry>")
61+
62+
def visit_sightseeing(self, sight: SightSeeing) -> None:
63+
print(f"<sight><name>{sight.name}</name><rating>{sight.rating}</rating></sight>")
64+
65+
66+
# ------------- CONCRETE VISITOR: Stats Collector -------------
67+
class StatsCollectorVisitor(GeoVisitor):
68+
def __init__(self):
69+
self.total_population = 0
70+
self.total_output = 0
71+
self.total_sights = 0
72+
73+
def visit_city(self, city: City) -> None:
74+
self.total_population += city.population
75+
76+
def visit_industry(self, industry: Industry) -> None:
77+
self.total_output += industry.annual_output
78+
79+
def visit_sightseeing(self, sight: SightSeeing) -> None:
80+
self.total_sights += 1
81+
82+
def report(self):
83+
print("---- Statistics Report ----")
84+
print(f"Total Population: {self.total_population}")
85+
print(f"Total Industry Output: {self.total_output}")
86+
print(f"Total Sightseeing Locations: {self.total_sights}")
87+
88+
89+
# ------------- CLIENT CODE -------------
90+
def main():
91+
graph = [
92+
City("New York", 8000000),
93+
Industry("Tech", 500_000_000),
94+
SightSeeing("Statue of Liberty", 4.9),
95+
City("San Francisco", 870000),
96+
Industry("Biotech", 300_000_000),
97+
]
98+
99+
print("=== Exporting to XML ===")
100+
xml_exporter = XMLExportVisitor()
101+
for node in graph:
102+
node.accept(xml_exporter)
103+
104+
print("\n=== Collecting Stats ===")
105+
stats_collector = StatsCollectorVisitor()
106+
for node in graph:
107+
node.accept(stats_collector)
108+
stats_collector.report()
109+
110+
111+
if __name__ == "__main__":
112+
main()

0 commit comments

Comments
 (0)