Python 工厂模式
工厂模式是一种创建型设计模式,它提供了一种创建对象的最佳方式。
想象一下你去餐厅点餐:你不需要知道厨房如何制作食物,只需要告诉服务员你想要什么,厨房就会为你制作出来。
工厂模式就是这样一个"厨房",它负责创建对象,而你只需要告诉它你想要什么类型的对象。
为什么需要工厂模式?
在编程中,我们经常需要创建对象。如果直接在代码中使用 new 关键字或类的构造函数,会导致:
- 代码耦合度高:创建对象的代码与具体类紧密绑定
- 维护困难:当需要修改或添加新的对象类型时,需要修改多处代码
- 违反开闭原则:对扩展开放,对修改关闭的原则被破坏
工厂模式通过将对象的创建过程封装起来,解决了这些问题。
工厂模式的三种类型
工厂模式主要分为三种类型,让我们通过具体的例子来理解它们。
1. 简单工厂模式
简单工厂模式是最基础的工厂模式,它通过一个工厂类来创建不同类型的对象。
基本结构
实例
from abc import ABC, abstractmethod
# 产品接口
class Animal(ABC):
@abstractmethod
def speak(self):
pass
# 具体产品
class Dog(Animal):
def speak(self):
return "汪汪!"
class Cat(Animal):
def speak(self):
return "喵喵!"
class Duck(Animal):
def speak(self):
return "嘎嘎!"
# 简单工厂
class AnimalFactory:
@staticmethod
def create_animal(animal_type):
if animal_type == "dog":
return Dog()
elif animal_type == "cat":
return Cat()
elif animal_type == "duck":
return Duck()
else:
raise ValueError(f"未知的动物类型: {animal_type}")
# 使用示例
def test_simple_factory():
factory = AnimalFactory()
dog = factory.create_animal("dog")
cat = factory.create_animal("cat")
duck = factory.create_animal("duck")
print(dog.speak()) # 输出: 汪汪!
print(cat.speak()) # 输出: 喵喵!
print(duck.speak()) # 输出: 嘎嘎!
if __name__ == "__main__":
test_simple_factory()
# 产品接口
class Animal(ABC):
@abstractmethod
def speak(self):
pass
# 具体产品
class Dog(Animal):
def speak(self):
return "汪汪!"
class Cat(Animal):
def speak(self):
return "喵喵!"
class Duck(Animal):
def speak(self):
return "嘎嘎!"
# 简单工厂
class AnimalFactory:
@staticmethod
def create_animal(animal_type):
if animal_type == "dog":
return Dog()
elif animal_type == "cat":
return Cat()
elif animal_type == "duck":
return Duck()
else:
raise ValueError(f"未知的动物类型: {animal_type}")
# 使用示例
def test_simple_factory():
factory = AnimalFactory()
dog = factory.create_animal("dog")
cat = factory.create_animal("cat")
duck = factory.create_animal("duck")
print(dog.speak()) # 输出: 汪汪!
print(cat.speak()) # 输出: 喵喵!
print(duck.speak()) # 输出: 嘎嘎!
if __name__ == "__main__":
test_simple_factory()
优点与缺点
优点:
- 客户端与具体产品类解耦
- 职责分离,易于维护
缺点:
- 添加新产品需要修改工厂类,违反开闭原则
- 工厂类职责过重,不符合单一职责原则
2. 工厂方法模式
工厂方法模式通过让子类决定创建什么对象来解决简单工厂模式的问题。
基本结构
实例
from abc import ABC, abstractmethod
# 产品接口
class Button(ABC):
@abstractmethod
def render(self):
pass
@abstractmethod
def onClick(self):
pass
# 具体产品
class WindowsButton(Button):
def render(self):
return "渲染 Windows 风格按钮"
def onClick(self):
return "Windows 按钮被点击"
class MacButton(Button):
def render(self):
return "渲染 Mac 风格按钮"
def onClick(self):
return "Mac 按钮被点击"
# 创建者抽象类
class Dialog(ABC):
@abstractmethod
def createButton(self) -> Button:
pass
def render(self):
# 调用工厂方法创建产品
button = self.createButton()
result = button.render()
return result
# 具体创建者
class WindowsDialog(Dialog):
def createButton(self) -> Button:
return WindowsButton()
class MacDialog(Dialog):
def createButton(self) -> Button:
return MacButton()
# 使用示例
def test_factory_method():
# 根据配置选择具体的工厂
config = "windows" # 可以从配置文件读取
if config == "windows":
dialog = WindowsDialog()
else:
dialog = MacDialog()
result = dialog.render()
print(result)
if __name__ == "__main__":
test_factory_method()
# 产品接口
class Button(ABC):
@abstractmethod
def render(self):
pass
@abstractmethod
def onClick(self):
pass
# 具体产品
class WindowsButton(Button):
def render(self):
return "渲染 Windows 风格按钮"
def onClick(self):
return "Windows 按钮被点击"
class MacButton(Button):
def render(self):
return "渲染 Mac 风格按钮"
def onClick(self):
return "Mac 按钮被点击"
# 创建者抽象类
class Dialog(ABC):
@abstractmethod
def createButton(self) -> Button:
pass
def render(self):
# 调用工厂方法创建产品
button = self.createButton()
result = button.render()
return result
# 具体创建者
class WindowsDialog(Dialog):
def createButton(self) -> Button:
return WindowsButton()
class MacDialog(Dialog):
def createButton(self) -> Button:
return MacButton()
# 使用示例
def test_factory_method():
# 根据配置选择具体的工厂
config = "windows" # 可以从配置文件读取
if config == "windows":
dialog = WindowsDialog()
else:
dialog = MacDialog()
result = dialog.render()
print(result)
if __name__ == "__main__":
test_factory_method()
工厂方法模式流程

3. 抽象工厂模式
抽象工厂模式提供一个创建一系列相关或依赖对象的接口,而无需指定它们具体的类。
基本结构
实例
from abc import ABC, abstractmethod
# 抽象产品 A
class Button(ABC):
@abstractmethod
def paint(self):
pass
# 抽象产品 B
class Checkbox(ABC):
@abstractmethod
def paint(self):
pass
# 具体产品 A1
class WindowsButton(Button):
def paint(self):
return "渲染 Windows 按钮"
# 具体产品 A2
class MacButton(Button):
def paint(self):
return "渲染 Mac 按钮"
# 具体产品 B1
class WindowsCheckbox(Checkbox):
def paint(self):
return "渲染 Windows 复选框"
# 具体产品 B2
class MacCheckbox(Checkbox):
def paint(self):
return "渲染 Mac 复选框"
# 抽象工厂
class GUIFactory(ABC):
@abstractmethod
def createButton(self) -> Button:
pass
@abstractmethod
def createCheckbox(self) -> Checkbox:
pass
# 具体工厂 1
class WindowsFactory(GUIFactory):
def createButton(self) -> Button:
return WindowsButton()
def createCheckbox(self) -> Checkbox:
return WindowsCheckbox()
# 具体工厂 2
class MacFactory(GUIFactory):
def createButton(self) -> Button:
return MacButton()
def createCheckbox(self) -> Checkbox:
return MacCheckbox()
# 客户端代码
class Application:
def __init__(self, factory: GUIFactory):
self.factory = factory
self.button = None
self.checkbox = None
def createUI(self):
self.button = self.factory.createButton()
self.checkbox = self.factory.createCheckbox()
def paint(self):
result = []
if self.button:
result.append(self.button.paint())
if self.checkbox:
result.append(self.checkbox.paint())
return "\n".join(result)
# 使用示例
def test_abstract_factory():
# 根据系统类型选择工厂
system_type = "windows" # 可以自动检测或从配置读取
if system_type == "windows":
factory = WindowsFactory()
else:
factory = MacFactory()
app = Application(factory)
app.createUI()
print(app.paint())
if __name__ == "__main__":
test_abstract_factory()
# 抽象产品 A
class Button(ABC):
@abstractmethod
def paint(self):
pass
# 抽象产品 B
class Checkbox(ABC):
@abstractmethod
def paint(self):
pass
# 具体产品 A1
class WindowsButton(Button):
def paint(self):
return "渲染 Windows 按钮"
# 具体产品 A2
class MacButton(Button):
def paint(self):
return "渲染 Mac 按钮"
# 具体产品 B1
class WindowsCheckbox(Checkbox):
def paint(self):
return "渲染 Windows 复选框"
# 具体产品 B2
class MacCheckbox(Checkbox):
def paint(self):
return "渲染 Mac 复选框"
# 抽象工厂
class GUIFactory(ABC):
@abstractmethod
def createButton(self) -> Button:
pass
@abstractmethod
def createCheckbox(self) -> Checkbox:
pass
# 具体工厂 1
class WindowsFactory(GUIFactory):
def createButton(self) -> Button:
return WindowsButton()
def createCheckbox(self) -> Checkbox:
return WindowsCheckbox()
# 具体工厂 2
class MacFactory(GUIFactory):
def createButton(self) -> Button:
return MacButton()
def createCheckbox(self) -> Checkbox:
return MacCheckbox()
# 客户端代码
class Application:
def __init__(self, factory: GUIFactory):
self.factory = factory
self.button = None
self.checkbox = None
def createUI(self):
self.button = self.factory.createButton()
self.checkbox = self.factory.createCheckbox()
def paint(self):
result = []
if self.button:
result.append(self.button.paint())
if self.checkbox:
result.append(self.checkbox.paint())
return "\n".join(result)
# 使用示例
def test_abstract_factory():
# 根据系统类型选择工厂
system_type = "windows" # 可以自动检测或从配置读取
if system_type == "windows":
factory = WindowsFactory()
else:
factory = MacFactory()
app = Application(factory)
app.createUI()
print(app.paint())
if __name__ == "__main__":
test_abstract_factory()
三种工厂模式对比
| 特性 | 简单工厂模式 | 工厂方法模式 | 抽象工厂模式 |
|---|---|---|---|
| 复杂度 | 低 | 中 | 高 |
| 扩展性 | 差 | 好 | 很好 |
| 适用场景 | 对象种类少 | 单一产品族 | 多个相关产品族 |
| 开闭原则 | 违反 | 遵守 | 遵守 |
| 依赖关系 | 依赖具体类 | 依赖抽象类 | 依赖抽象接口 |
实际应用场景
场景 1:数据库连接工厂
实例
from abc import ABC, abstractmethod
import sqlite3
import mysql.connector
# 数据库连接接口
class DatabaseConnection(ABC):
@abstractmethod
def connect(self):
pass
@abstractmethod
def execute(self, query):
pass
# 具体数据库连接
class SQLiteConnection(DatabaseConnection):
def __init__(self, db_path):
self.db_path = db_path
self.connection = None
def connect(self):
self.connection = sqlite3.connect(self.db_path)
return self.connection
def execute(self, query):
if self.connection:
cursor = self.connection.cursor()
cursor.execute(query)
return cursor.fetchall()
class MySQLConnection(DatabaseConnection):
def __init__(self, host, user, password, database):
self.host = host
self.user = user
self.password = password
self.database = database
self.connection = None
def connect(self):
self.connection = mysql.connector.connect(
host=self.host,
user=self.user,
password=self.password,
database=self.database
)
return self.connection
def execute(self, query):
if self.connection:
cursor = self.connection.cursor()
cursor.execute(query)
return cursor.fetchall()
# 数据库工厂
class DatabaseFactory:
@staticmethod
def create_connection(db_type, **kwargs):
if db_type == "sqlite":
return SQLiteConnection(**kwargs)
elif db_type == "mysql":
return MySQLConnection(**kwargs)
else:
raise ValueError(f"不支持的数据库类型: {db_type}")
# 使用示例
def test_database_factory():
# 创建 SQLite 连接
sqlite_conn = DatabaseFactory.create_connection(
"sqlite",
db_path="example.db"
)
sqlite_conn.connect()
# 创建 MySQL 连接
mysql_conn = DatabaseFactory.create_connection(
"mysql",
host="localhost",
user="root",
password="password",
database="test"
)
mysql_conn.connect()
print("数据库连接创建成功!")
if __name__ == "__main__":
test_database_factory()
import sqlite3
import mysql.connector
# 数据库连接接口
class DatabaseConnection(ABC):
@abstractmethod
def connect(self):
pass
@abstractmethod
def execute(self, query):
pass
# 具体数据库连接
class SQLiteConnection(DatabaseConnection):
def __init__(self, db_path):
self.db_path = db_path
self.connection = None
def connect(self):
self.connection = sqlite3.connect(self.db_path)
return self.connection
def execute(self, query):
if self.connection:
cursor = self.connection.cursor()
cursor.execute(query)
return cursor.fetchall()
class MySQLConnection(DatabaseConnection):
def __init__(self, host, user, password, database):
self.host = host
self.user = user
self.password = password
self.database = database
self.connection = None
def connect(self):
self.connection = mysql.connector.connect(
host=self.host,
user=self.user,
password=self.password,
database=self.database
)
return self.connection
def execute(self, query):
if self.connection:
cursor = self.connection.cursor()
cursor.execute(query)
return cursor.fetchall()
# 数据库工厂
class DatabaseFactory:
@staticmethod
def create_connection(db_type, **kwargs):
if db_type == "sqlite":
return SQLiteConnection(**kwargs)
elif db_type == "mysql":
return MySQLConnection(**kwargs)
else:
raise ValueError(f"不支持的数据库类型: {db_type}")
# 使用示例
def test_database_factory():
# 创建 SQLite 连接
sqlite_conn = DatabaseFactory.create_connection(
"sqlite",
db_path="example.db"
)
sqlite_conn.connect()
# 创建 MySQL 连接
mysql_conn = DatabaseFactory.create_connection(
"mysql",
host="localhost",
user="root",
password="password",
database="test"
)
mysql_conn.connect()
print("数据库连接创建成功!")
if __name__ == "__main__":
test_database_factory()
场景 2:日志记录器工厂
实例
import logging
from abc import ABC, abstractmethod
import sys
# 日志记录器接口
class Logger(ABC):
@abstractmethod
def info(self, message):
pass
@abstractmethod
def error(self, message):
pass
@abstractmethod
def debug(self, message):
pass
# 控制台日志记录器
class ConsoleLogger(Logger):
def info(self, message):
print(f"INFO: {message}")
def error(self, message):
print(f"ERROR: {message}", file=sys.stderr)
def debug(self, message):
print(f"DEBUG: {message}")
# 文件日志记录器
class FileLogger(Logger):
def __init__(self, filename):
self.filename = filename
def info(self, message):
with open(self.filename, 'a') as f:
f.write(f"INFO: {message}\n")
def error(self, message):
with open(self.filename, 'a') as f:
f.write(f"ERROR: {message}\n")
def debug(self, message):
with open(self.filename, 'a') as f:
f.write(f"DEBUG: {message}\n")
# 日志工厂
class LoggerFactory:
@staticmethod
def get_logger(logger_type, **kwargs):
if logger_type == "console":
return ConsoleLogger()
elif logger_type == "file":
return FileLogger(**kwargs)
else:
raise ValueError(f"不支持的日志类型: {logger_type}")
# 使用示例
def test_logger_factory():
# 创建控制台日志记录器
console_logger = LoggerFactory.get_logger("console")
console_logger.info("这是一个信息消息")
console_logger.error("这是一个错误消息")
# 创建文件日志记录器
file_logger = LoggerFactory.get_logger("file", filename="app.log")
file_logger.info("记录到文件的信息")
file_logger.debug("调试信息")
if __name__ == "__main__":
test_logger_factory()
from abc import ABC, abstractmethod
import sys
# 日志记录器接口
class Logger(ABC):
@abstractmethod
def info(self, message):
pass
@abstractmethod
def error(self, message):
pass
@abstractmethod
def debug(self, message):
pass
# 控制台日志记录器
class ConsoleLogger(Logger):
def info(self, message):
print(f"INFO: {message}")
def error(self, message):
print(f"ERROR: {message}", file=sys.stderr)
def debug(self, message):
print(f"DEBUG: {message}")
# 文件日志记录器
class FileLogger(Logger):
def __init__(self, filename):
self.filename = filename
def info(self, message):
with open(self.filename, 'a') as f:
f.write(f"INFO: {message}\n")
def error(self, message):
with open(self.filename, 'a') as f:
f.write(f"ERROR: {message}\n")
def debug(self, message):
with open(self.filename, 'a') as f:
f.write(f"DEBUG: {message}\n")
# 日志工厂
class LoggerFactory:
@staticmethod
def get_logger(logger_type, **kwargs):
if logger_type == "console":
return ConsoleLogger()
elif logger_type == "file":
return FileLogger(**kwargs)
else:
raise ValueError(f"不支持的日志类型: {logger_type}")
# 使用示例
def test_logger_factory():
# 创建控制台日志记录器
console_logger = LoggerFactory.get_logger("console")
console_logger.info("这是一个信息消息")
console_logger.error("这是一个错误消息")
# 创建文件日志记录器
file_logger = LoggerFactory.get_logger("file", filename="app.log")
file_logger.info("记录到文件的信息")
file_logger.debug("调试信息")
if __name__ == "__main__":
test_logger_factory()
最佳实践和注意事项
1. 何时使用工厂模式?
适合使用工厂模式的情况:
- 创建对象的过程比较复杂
- 需要根据不同的条件创建不同的对象
- 希望将对象的创建与使用分离
- 系统需要支持多种类型的产品
不适合使用的情况:
- 对象的创建过程很简单,直接使用构造函数即可
- 产品类型很少,且不太可能扩展
2. 常见错误和避免方法
错误 1:过度设计
实例
# 不推荐:简单情况使用复杂工厂
class SimpleObject:
def __init__(self, name):
self.name = name
# 过度设计的工厂
class SimpleObjectFactory:
@staticmethod
def create_simple_object(name):
return SimpleObject(name)
# 推荐:直接创建
obj = SimpleObject("test")
class SimpleObject:
def __init__(self, name):
self.name = name
# 过度设计的工厂
class SimpleObjectFactory:
@staticmethod
def create_simple_object(name):
return SimpleObject(name)
# 推荐:直接创建
obj = SimpleObject("test")
错误 2:工厂类职责过多
实例
# 不推荐:一个工厂做太多事情
class GodFactory:
def create_user(self): ...
def create_order(self): ...
def create_product(self): ...
def send_email(self): ... # 这不是创建对象!
# 推荐:按职责分离
class UserFactory: ...
class OrderFactory: ...
class ProductFactory: ...
class GodFactory:
def create_user(self): ...
def create_order(self): ...
def create_product(self): ...
def send_email(self): ... # 这不是创建对象!
# 推荐:按职责分离
class UserFactory: ...
class OrderFactory: ...
class ProductFactory: ...
3. 与依赖注入的结合
工厂模式经常与依赖注入(DI)一起使用:
实例
from abc import ABC, abstractmethod
# 服务接口
class NotificationService(ABC):
@abstractmethod
def send(self, message):
pass
# 具体服务
class EmailService(NotificationService):
def send(self, message):
return f"发送邮件: {message}"
class SMSService(NotificationService):
def send(self, message):
return f"发送短信: {message}"
# 工厂
class NotificationFactory:
@staticmethod
def create_service(service_type):
if service_type == "email":
return EmailService()
elif service_type == "sms":
return SMSService()
else:
raise ValueError(f"未知的服务类型: {service_type}")
# 使用依赖注入的类
class OrderProcessor:
def __init__(self, notification_service: NotificationService):
self.notification_service = notification_service
def process_order(self, order):
# 处理订单逻辑
result = self.notification_service.send("订单处理完成")
return result
# 使用
def main():
# 通过工厂创建服务
notification_service = NotificationFactory.create_service("email")
# 注入依赖
processor = OrderProcessor(notification_service)
result = processor.process_order({"id": 1})
print(result)
if __name__ == "__main__":
main()
# 服务接口
class NotificationService(ABC):
@abstractmethod
def send(self, message):
pass
# 具体服务
class EmailService(NotificationService):
def send(self, message):
return f"发送邮件: {message}"
class SMSService(NotificationService):
def send(self, message):
return f"发送短信: {message}"
# 工厂
class NotificationFactory:
@staticmethod
def create_service(service_type):
if service_type == "email":
return EmailService()
elif service_type == "sms":
return SMSService()
else:
raise ValueError(f"未知的服务类型: {service_type}")
# 使用依赖注入的类
class OrderProcessor:
def __init__(self, notification_service: NotificationService):
self.notification_service = notification_service
def process_order(self, order):
# 处理订单逻辑
result = self.notification_service.send("订单处理完成")
return result
# 使用
def main():
# 通过工厂创建服务
notification_service = NotificationFactory.create_service("email")
# 注入依赖
processor = OrderProcessor(notification_service)
result = processor.process_order({"id": 1})
print(result)
if __name__ == "__main__":
main()
练习题目
练习 1:实现形状工厂
创建一个形状工厂,支持创建圆形(Circle)、矩形(Rectangle)和三角形(Triangle)。每个形状都应该有计算面积和周长的方怯。
要求:
- 使用工厂方法模式
- 每个形状类实现
calculate_area()和calculate_perimeter()方法 - 提供使用示例
练习 2:扩展数据库工厂
基于前面的数据库工厂示例,添加对 PostgreSQL 数据库的支持。
要求:
- 创建
PostgreSQLConnection类 - 修改工厂以支持新数据库类型
- 确保不破坏现有代码
练习 3:配置驱动的工厂
创建一个可以根据配置文件动态选择工厂的系统。
要求:
- 从 JSON 或 YAML 文件读取配置
- 根据配置创建相应的对象
- 支持热更新配置
总结
工厂模式是 Python 设计中非常重要的创建型模式,它通过将对象的创建过程封装起来,提供了以下好处:
- 降低耦合度:客户端不需要知道具体产品的创建细节
- 提高可维护性:创建逻辑集中管理,易于修改和扩展
- 增强灵活性:可以轻松添加新的产品类型
- 促进代码复用:创建逻辑可以在多个地方重用
记住选择适合的工厂模式类型:
- 简单工厂:适用于产品类型较少且不太变化的场景
- 工厂方法:适用于需要扩展产品族的场景
- 抽象工厂:适用于需要创建相关产品族的复杂场景
通过合理使用工厂模式,你可以编写出更加灵活、可维护和可扩展的 Python 代码!
