策略模式(Strategy Pattern)是一种行为设计模式,它允许你定义一系列算法,将每个算法封装起来,并使它们可以相互替换。策略模式让算法的变化独立于使用算法的客户。
基本概念
策略模式包含以下主要角色:
-
Context(上下文): 维护一个对Strategy对象的引用
-
Strategy(策略): 定义所有支持的算法的公共接口
-
ConcreteStrategy(具体策略): 实现Strategy接口的具体算法
Python实现示例
from abc import ABC, abstractmethod
from typing import List
# 策略接口
class SortStrategy(ABC):
@abstractmethod
def sort(self, data: List) -> List:
pass
# 具体策略A: 快速排序
class QuickSortStrategy(SortStrategy):
def sort(self, data: List) -> List:
print("使用快速排序算法")
if len(data) <= 1:
return data
pivot = data[len(data) // 2]
left = [x for x in data if x < pivot]
middle = [x for x in data if x == pivot]
right = [x for x in data if x > pivot]
return self.sort(left) + middle + self.sort(right)
# 具体策略B: 冒泡排序
class BubbleSortStrategy(SortStrategy):
def sort(self, data: List) -> List:
print("使用冒泡排序算法")
n = len(data)
for i in range(n):
for j in range(0, n-i-1):
if data[j] > data[j+1]:
data[j], data[j+1] = data[j+1], data[j]
return data
# 上下文类
class Sorter:
def __init__(self, strategy: SortStrategy):
self._strategy = strategy
def set_strategy(self, strategy: SortStrategy):
self._strategy = strategy
def perform_sort(self, data: List) -> List:
return self._strategy.sort(data)
# 客户端代码
if __name__ == "__main__":
data = [5, 2, 8, 1, 9, 3]
sorter = Sorter(QuickSortStrategy())
print("初始策略:")
result = sorter.perform_sort(data.copy())
print("排序结果:", result)
print("\n切换策略:")
sorter.set_strategy(BubbleSortStrategy())
result = sorter.perform_sort(data.copy())
print("排序结果:", result)
输出示例
初始策略: 使用快速排序算法 排序结果: [1, 2, 3, 5, 8, 9] 切换策略: 使用冒泡排序算法 排序结果: [1, 2, 3, 5, 8, 9]
策略模式的优点
-
开闭原则:可以在不修改上下文代码的情况下引入新策略
-
消除条件语句:避免使用多重条件判断来选择算法
-
提高灵活性:运行时可以切换算法
-
算法复用:不同上下文可以共享策略对象
Python中的简化实现
Python支持一等函数(first-class functions),这使得策略模式可以更简洁地实现:
from typing import List, Callable
def quick_sort(data: List) -> List:
print("使用快速排序算法")
if len(data) <= 1:
return data
pivot = data[len(data) // 2]
left = [x for x in data if x < pivot]
middle = [x for x in data if x == pivot]
right = [x for x in data if x > pivot]
return quick_sort(left) + middle + quick_sort(right)
def bubble_sort(data: List) -> List:
print("使用冒泡排序算法")
n = len(data)
for i in range(n):
for j in range(0, n-i-1):
if data[j] > data[j+1]:
data[j], data[j+1] = data[j+1], data[j]
return data
class Sorter:
def __init__(self, strategy: Callable[[List], List]):
self._strategy = strategy
def set_strategy(self, strategy: Callable[[List], List]):
self._strategy = strategy
def perform_sort(self, data: List) -> List:
return self._strategy(data)
# 使用
data = [5, 2, 8, 1, 9, 3]
sorter = Sorter(quick_sort)
print(sorter.perform_sort(data.copy()))
sorter.set_strategy(bubble_sort)
print(sorter.perform_sort(data.copy()))
这种实现利用了Python的函数特性,使代码更加简洁。
订单折扣策略实现示例
"""
*What is this pattern about?
Define a family of algorithms, encapsulate each one, and make them interchangeable.
Strategy lets the algorithm vary independently from clients that use it.
*TL;DR
Enables selecting an algorithm at runtime.
"""
from __future__ import annotations
from typing import Callable
class DiscountStrategyValidator: # Descriptor class for check perform
@staticmethod
def validate(obj: Order, value: Callable) -> bool:
try:
if obj.price - value(obj) < 0:
raise ValueError(
f"Discount cannot be applied due to negative price resulting. {value.__name__}"
)
except ValueError as ex:
print(str(ex))
return False
else:
return True
def __set_name__(self, owner, name: str) -> None:
self.private_name = f"_{name}"
def __set__(self, obj: Order, value: Callable = None) -> None:
if value and self.validate(obj, value):
setattr(obj, self.private_name, value)
else:
setattr(obj, self.private_name, None)
def __get__(self, obj: object, objtype: type = None):
return getattr(obj, self.private_name)
class Order:
discount_strategy = DiscountStrategyValidator()
def __init__(self, price: float, discount_strategy: Callable = None) -> None:
self.price: float = price
self.discount_strategy = discount_strategy
def apply_discount(self) -> float:
if self.discount_strategy:
discount = self.discount_strategy(self)
else:
discount = 0
return self.price - discount
def __repr__(self) -> str:
strategy = getattr(self.discount_strategy, "__name__", None)
return f"<Order price: {self.price} with discount strategy: {strategy}>"
def ten_percent_discount(order: Order) -> float:
return order.price * 0.10
def on_sale_discount(order: Order) -> float:
return order.price * 0.25 + 20
def main():
"""
>>> order = Order(100, discount_strategy=ten_percent_discount)
>>> print(order)
<Order price: 100 with discount strategy: ten_percent_discount>
>>> print(order.apply_discount())
90.0
>>> order = Order(100, discount_strategy=on_sale_discount)
>>> print(order)
<Order price: 100 with discount strategy: on_sale_discount>
>>> print(order.apply_discount())
55.0
>>> order = Order(10, discount_strategy=on_sale_discount)
Discount cannot be applied due to negative price resulting. on_sale_discount
>>> print(order)
<Order price: 10 with discount strategy: None>
"""
if __name__ == "__main__":
import doctest
doctest.testmod()
正如代码开头的注释所述:
-
策略模式定义了一系列算法(折扣策略)
-
将每个算法封装起来
-
使它们可以互换
-
让算法的变化独立于使用它的客户端(Order类)
主要组件
1 DiscountStrategyValidator 类
这是一个描述符类,用于验证折扣策略的有效性:
class DiscountStrategyValidator:
@staticmethod
def validate(obj: Order, value: Callable) -> bool:
try:
if obj.price - value(obj) < 0:
raise ValueError(
f"Discount cannot be applied due to negative price resulting. {value.__name__}"
)
except ValueError as ex:
print(str(ex))
return False
else:
return True
def __set_name__(self, owner, name: str) -> None:
self.private_name = f"_{name}"
def __set__(self, obj: Order, value: Callable = None) -> None:
if value and self.validate(obj, value):
setattr(obj, self.private_name, value)
else:
setattr(obj, self.private_name, None)
def __get__(self, obj: object, objtype: type = None):
return getattr(obj, self.private_name)
功能解析:
-
validate
方法检查应用折扣后价格是否为负 -
__set_name__
设置私有属性名(约定为_加属性名) -
__set__
在设置值前进行验证 -
__get__
获取属性值
2 Order 类
这是上下文类,使用策略模式:
class Order:
discount_strategy = DiscountStrategyValidator()
def __init__(self, price: float, discount_strategy: Callable = None) -> None:
self.price: float = price
self.discount_strategy = discount_strategy
def apply_discount(self) -> float:
if self.discount_strategy:
discount = self.discount_strategy(self)
else:
discount = 0
return self.price - discount
def __repr__(self) -> str:
strategy = getattr(self.discount_strategy, "__name__", None)
return f"<Order price: {self.price} with discount strategy: {strategy}>"
关键点:
-
使用描述符管理折扣策略
-
apply_discount
方法应用当前策略 -
可以动态更换策略
3 具体策略实现
def ten_percent_discount(order: Order) -> float:
return order.price * 0.10
def on_sale_discount(order: Order) -> float:
return order.price * 0.25 + 20
这是两个具体策略:
-
ten_percent_discount
: 10%折扣 -
on_sale_discount
: 25%折扣外加20元优惠
使用示例
def main():
order = Order(100, discount_strategy=ten_percent_discount)
print(order)
print(order.apply_discount())
order = Order(100, discount_strategy=on_sale_discount)
print(order)
print(order.apply_discount())
order = Order(10, discount_strategy=on_sale_discount)
Discount cannot be applied due to negative price resulting. on_sale_discount
print(order)
测试案例说明:
-
第一个订单应用10%折扣,价格从100降到90
-
第二个订单应用促销折扣,价格从100降到55
-
第三个订单价格太低(10元),应用促销折扣会导致负价格,策略被拒绝
设计亮点
-
描述符的使用:通过描述符实现了策略的自动验证
-
类型提示:全面使用类型注解,提高代码可读性
-
防御性编程:防止无效折扣导致负价格
-
文档测试:使用doctest模块提供自包含的测试案例
与传统策略模式的对比
传统策略模式通常使用抽象基类定义接口,而这里利用了Python的鸭子类型和函数特性:
-
策略只是可调用对象(Callable)
-
不需要继承抽象基类
-
更符合Python的简洁哲学