Iterators and Generators

Iterables and iterators are often confused, so let's pin down the difference first:

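| | Iterable | Iterator |
|---|---|---|
| Definition | An object that can be traversed with a for loop | An object that remembers its current position and hands out elements one at a time |
| Examples | list, str, dict, range | The return values of zip(), map(), filter() |
| Conversion | iter(iterable) → iterator | next(iterator) → next element |
| Reusable | Yes | No (empty once it has been consumed) |
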
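nums = [1, 2, 3]       # a list is an iterable, not an iterator

it = iter(nums)        # iter() builds an iterator over the list
print(next(it))  # 1
print(next(it))  # 2
print(next(it))  # 3
try:
    next(it)           # nothing is left, so this raises StopIteration
except StopIteration:
    print("exhausted")  # exhausted

# A list can be iterated as many times as you like
for n in nums:
    print(n)
for n in nums:         # a second pass works fine
    print(n)

# An iterator is empty once it has been consumed
it2 = iter(nums)
print(list(it2))       # [1, 2, 3]
print(list(it2))       # [] (already exhausted)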
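
One way to check which side of the distinction an object falls on is the abstract base classes in `collections.abc`. The snippet below is a small sketch that reuses the `nums` list from the example above; it also shows that an iterator returns itself from `iter()`, while a list hands out a new iterator each time.

```python
from collections.abc import Iterable, Iterator

nums = [1, 2, 3]
it = iter(nums)

# A list is iterable but is not itself an iterator.
print(isinstance(nums, Iterable))  # True
print(isinstance(nums, Iterator))  # False

# An iterator is also iterable: iter() on it returns the same object.
print(isinstance(it, Iterator))    # True
print(iter(it) is it)              # True

# Each iter() call on the list creates a fresh, independent iterator.
print(iter(nums) is iter(nums))    # False
```

This is why a list can be looped over repeatedly: every loop starts from a new iterator, whereas looping over an iterator keeps advancing the same one.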

What a for loop really does

Behind the scenes, a for loop does roughly the following:

for x in [1, 2, 3]:
    print(x)

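# is roughly equivalent to
_iter = iter([1, 2, 3])
while True:
    try:
        x = next(_iter)
    except StopIteration:
        break              # the iterator is exhausted, so leave the loop
    print(x)               # the loop body runs outside the try block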
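
Because a for loop only relies on `iter()` and `next()`, any class that provides `__iter__` and `__next__` can be looped over. The following is a minimal sketch; `Countdown` and `CountdownIterator` are hypothetical names made up for illustration, not part of any library.

```python
class Countdown:
    """Hypothetical iterable that counts down from `start` to 1."""

    def __init__(self, start):
        self.start = start

    def __iter__(self):
        # Return a fresh iterator each time, so the object is reusable.
        return CountdownIterator(self.start)


class CountdownIterator:
    """Hypothetical iterator that remembers the current position."""

    def __init__(self, current):
        self.current = current

    def __iter__(self):
        # Iterators return themselves from __iter__.
        return self

    def __next__(self):
        if self.current <= 0:
            raise StopIteration   # signals the for loop to stop
        value = self.current
        self.current -= 1
        return value


countdown = Countdown(3)
print(list(countdown))  # [3, 2, 1]
print(list(countdown))  # [3, 2, 1] (a new iterator is created each time)
for n in countdown:
    print(n)
```

Splitting the iterable from its iterator is a design choice: it keeps the position state out of `Countdown`, which is what makes repeated loops work.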

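Why do zip(), map(), and filter() need list()?

In Python 3 these functions return an iterator, not a list, so wrap the result in list() whenever you need an actual list (to print it, index it, or loop over it more than once). The benefit of an iterator is lazy evaluation: elements are not all produced up front but computed only when they are requested, which saves memory.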

# map() returns an iterator
result = map(lambda x: x ** 2, range(10))
print(result)        # <map object at 0x...>
print(list(result))  # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

# With a million items, the advantage of an iterator is clear
big = map(lambda x: x ** 2, range(1_000_000))
# Nothing has been computed yet, so no large list sits in memory

# A for loop can consume the iterator directly; no list() is needed
for val in map(lambda x: x ** 2, range(5)):
    print(val)
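
Lazy evaluation can be observed directly by recording when the mapped function is actually called. This is a small sketch; the `calls` list and the `square` function are illustrative names introduced here.

```python
calls = []  # records every argument square() is actually called with

def square(x):
    calls.append(x)
    return x ** 2

lazy = map(square, [1, 2, 3])
print(calls)        # [] (map() has not called square() yet)

print(next(lazy))   # 1
print(calls)        # [1] (only the first element has been computed)

print(list(lazy))   # [4, 9]
print(calls)        # [1, 2, 3]
```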

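Generators

A generator is a special kind of iterator, defined with the yield keyword. Calling a generator function returns a generator object without running the function body. Each next() call then runs the body until it reaches a yield, where the function pauses and hands back a value; the following next() call resumes from that point: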

def count_up(start, end):
    current = start
    while current <= end:
        yield current       # pause here and hand back current
        current += 1        # the next next() call resumes here

gen = count_up(1, 5)
print(next(gen))  # 1
print(next(gen))  # 2
print(next(gen))  # 3

# A for loop works too
for n in count_up(1, 5):
    print(n)
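
To see the pause-and-resume behaviour step by step, the sketch below logs which parts of a generator body have run. `traced` and `log` are hypothetical names used only for this illustration.

```python
log = []  # trace of what the generator body has executed so far

def traced():
    log.append("started")
    yield 1
    log.append("resumed")
    yield 2
    log.append("finished")

gen = traced()
print(log)         # [] (calling the function runs none of the body)

print(next(gen))   # 1
print(log)         # ['started']

print(next(gen))   # 2
print(log)         # ['started', 'resumed']

try:
    next(gen)      # the body runs to the end, then StopIteration is raised
except StopIteration:
    print(log)     # ['started', 'resumed', 'finished']
```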

Comparison with a regular function

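import sys

# Regular function: builds every result in memory at once
def squares_list(n):
    return [x ** 2 for x in range(n)]

# Generator function: computes each value on demand, saving memory
def squares_gen(n):
    for x in range(n):
        yield x ** 2

# Both can be consumed by a for loop in the same way
for sq in squares_gen(5):
    print(sq)   # 0 1 4 9 16

# The object sizes differ a lot (the gap grows with the number of items).
# Exact numbers depend on the Python version and platform.
print(sys.getsizeof(squares_list(1000)))   # roughly 8-9 KB on 64-bit CPython (the list's pointer array only)
print(sys.getsizeof(squares_gen(1000)))    # a few hundred bytes at most (just the generator object)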
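
`sys.getsizeof()` measures only the container object itself, not the elements it refers to. For a rough end-to-end comparison, the standard `tracemalloc` module can report peak allocation while the values are summed. This is a sketch only: `peak_bytes` is a hypothetical helper, and the figures it prints will differ between Python versions and machines.

```python
import tracemalloc

def peak_bytes(make_iterable):
    """Return the peak traced allocation while summing the iterable."""
    tracemalloc.start()
    sum(make_iterable())
    _, peak = tracemalloc.get_traced_memory()
    tracemalloc.stop()   # stopping clears the traces, so each run starts fresh
    return peak

n = 100_000
list_peak = peak_bytes(lambda: [x ** 2 for x in range(n)])
gen_peak = peak_bytes(lambda: (x ** 2 for x in range(n)))

print(f"list comprehension peak:   {list_peak:,} bytes")
print(f"generator expression peak: {gen_peak:,} bytes")
print(gen_peak < list_peak)  # expected to be True; exact figures vary
```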

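Generator expressions

A generator expression looks like a list comprehension but is wrapped in (), and it produces a generator directly: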

# List comprehension (builds the whole list immediately)
sq_list = [x ** 2 for x in range(10)]

# Generator expression (lazy evaluation)
sq_gen = (x ** 2 for x in range(10))

print(sq_list)  # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
print(sq_gen)   # <generator object <genexpr> at 0x...>
print(list(sq_gen))  # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

# When passed directly to sum(), max(), and similar functions, no list() is needed
total = sum(x ** 2 for x in range(10))
print(total)  # 285
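
Two consequences of laziness are worth seeing in code: a generator expression can be consumed only once, and a consumer such as `any()` stops pulling values as soon as it has its answer. The variable names below are illustrative.

```python
squares = (x ** 2 for x in range(10))

print(sum(squares))  # 285
print(sum(squares))  # 0 (the generator is already exhausted)

# any() stops pulling values as soon as it sees a truthy one,
# so only a small part of this large range is ever examined.
numbers = iter(range(1_000_000))
print(any(n > 3 for n in numbers))  # True
print(next(numbers))                # 5 (values 0 through 4 were consumed)
```

If the values are needed more than once, build a list instead, or recreate the generator expression for each pass.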

Advanced uses of yield

Multiple yield statements

def weekdays():
    yield "Monday"
    yield "Tuesday"
    yield "Wednesday"
    yield "Thursday"
    yield "Friday"

for day in weekdays():
    print(day)

yield from: delegating to another iterable

def flatten(nested):
    for sublist in nested:
        yield from sublist   # same as: for item in sublist: yield item

data = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
print(list(flatten(data)))  # [1, 2, 3, 4, 5, 6, 7, 8, 9]
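
`yield from` can also delegate to another generator, including a recursive call. The sketch below extends the idea to lists nested to any depth; `flatten_deep` is a hypothetical name, and treating only `list` instances as nested containers is an assumption made to keep the example short.

```python
def flatten_deep(nested):
    """Hypothetical variant that flattens lists nested to any depth."""
    for item in nested:
        if isinstance(item, list):
            # Delegate to a recursive generator for the inner list.
            yield from flatten_deep(item)
        else:
            yield item

data = [1, [2, [3, 4]], [[5], 6], 7]
print(list(flatten_deep(data)))  # [1, 2, 3, 4, 5, 6, 7]
```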

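Practical examples

Infinite sequences

A generator can produce an infinite sequence because it never has to build all the elements in advance: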

def fibonacci():
    a, b = 0, 1
    while True:
        yield a
        a, b = b, a + b

# Take the first 10 Fibonacci numbers
gen = fibonacci()
fibs = [next(gen) for _ in range(10)]
print(fibs)  # [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
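
An alternative to calling `next()` in a comprehension is to slice the infinite generator with the standard `itertools` helpers. The sketch below repeats the `fibonacci` generator so that it runs on its own.

```python
from itertools import islice, takewhile

def fibonacci():
    a, b = 0, 1
    while True:
        yield a
        a, b = b, a + b

# islice() stops after 10 items, so the infinite generator is safe to use.
print(list(islice(fibonacci(), 10)))
# [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]

# takewhile() stops at the first value that fails the condition.
print(list(takewhile(lambda n: n < 100, fibonacci())))
# [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89]
```

Passing an infinite generator straight to `list()` or `sum()` would never finish, so some bounding step like these is always required.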

Processing large data in batches

def read_in_chunks(data, chunk_size):
    """Split a large sequence into small batches and yield them one at a time."""
    for i in range(0, len(data), chunk_size):
        yield data[i:i + chunk_size]

records = list(range(1, 101))   # simulate 100 records

for batch in read_in_chunks(records, chunk_size=10):
    print(f"Processing records {batch[0]}~{batch[-1]}")

# Processing records 1~10
# Processing records 11~20
# ...
# Processing records 91~100
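
`read_in_chunks` relies on `len()` and slicing, so it only works on sequences that are already in memory. For a source that is itself an iterator (a file, a generator, a database cursor), one possible approach is to pull items with `itertools.islice`. `iter_chunks` below is a hypothetical helper sketched under that assumption.

```python
from itertools import islice

def iter_chunks(iterable, chunk_size):
    """Hypothetical helper: yield lists of up to chunk_size items."""
    if chunk_size < 1:
        raise ValueError("chunk_size must be at least 1")
    iterator = iter(iterable)
    while True:
        chunk = list(islice(iterator, chunk_size))
        if not chunk:
            return          # source exhausted: end the generator
        yield chunk

stream = (n * n for n in range(7))   # has no len() and cannot be sliced
for chunk in iter_chunks(stream, 3):
    print(chunk)
# [0, 1, 4]
# [9, 16, 25]
# [36]
```

On Python 3.12 and later, `itertools.batched` covers the same need, yielding tuples rather than lists.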

Moving average

def moving_average(prices, window):
    """Compute a moving average with a generator, without storing all the results at once."""
    buffer = []
    for price in prices:
        buffer.append(price)
        if len(buffer) > window:
            buffer.pop(0)
        if len(buffer) == window:
            yield sum(buffer) / window

prices = [100, 102, 98, 105, 110, 108, 112, 115]
ma3 = list(moving_average(prices, window=3))
print([round(x, 2) for x in ma3])
# [100.0, 101.67, 104.33, 107.67, 110.0, 111.67]
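
`buffer.pop(0)` shifts every remaining element of the list, which is fine for a small window but costs more as the window grows. A possible alternative is `collections.deque` with `maxlen`, which discards the oldest item automatically. `moving_average_deque` is a hypothetical variant written for comparison; it is meant to produce the same values as the version above.

```python
from collections import deque

def moving_average_deque(prices, window):
    """Hypothetical variant of moving_average using a bounded deque."""
    buffer = deque(maxlen=window)   # the oldest price is dropped automatically
    for price in prices:
        buffer.append(price)
        if len(buffer) == window:
            yield sum(buffer) / window

prices = [100, 102, 98, 105, 110, 108, 112, 115]
print([round(x, 2) for x in moving_average_deque(prices, window=3)])
# [100.0, 101.67, 104.33, 107.67, 110.0, 111.67]
```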