Python MultiProcessing 使用心得
由于 GIL 的缘故,Python 下计算密集型并行一般推荐使用进程并行,如 multiprocessing。但是 multiprocessing 的使用中实在是有很多 tricky 的地方,不小心就会掉坑里。说到底,主要还是因为对它的具体实现机制不够了解。以下是个人使用中总结的一些心得,主要关注进程池和共享内存。
请注意,本文为 linux/mac 下的使用经验,windows 没有 fork 机制,部分地方可能不同。另外,本文假定读者已经知道什么是 Pool,map,apply
。
注意程序组织
要并行调用的函数和参数都应在 Pool 实例化之前定义好,并且要注意可序列化(pickle)性,避免使用 lambda
匿名函数。在实例化 Pool 对象时,fork 已经完成。之后主进程中变量更新不会影响子进程。
进程池一定要在被调函数定义之后实例化。
什么可以被 pickle
None,True,False,内建数字类型,字符串
picklable 对象组成的 tuple、list、set、dict
在模块顶层中定义的函数、内建函数、类
其
__dict__
或__getstate__()
是可 pickle 的类实例特别地,numpy ndarray
注意 lambda
匿名函数不可被 pickle,除非使用 dill
之类的包,见此。
pickle 并不是什么特别神秘的东西,心态要放松,别怕。
遇到不能 pickle 的东西
用可 pickle 的对象包装一下。
比如,实例方法不可pickle,那我们可以传递实例,再调用方法
def work(foo):
foo.work()
pool.apply_async(work,args=(foo,))
如果是实例,可修改其
__getstate__
方法,过滤掉不可 pickle 的东西。
Decoration
装饰器是非常实用的小技巧,可惜装饰器不能被 pickle。原因是 pickle 通过函数名在模块中查找函数,decorated function 的 __name__
属性是它定义时的名字,一般不是它现在的变量名。幸运的是,很容易用 functools
绕过这个问题,解决方案见此。
可以在下面 map 多参数 一节找到例子。
共享只读变量
原则上说,一般尽量避免使用全局变量,还是通过 map
或者 apply
将参数传给 Pool
对象比较规范。不过实践中,为了速度,有时只能做出一点让步。参数是通过 pickle
传递的,效率较低;而全局变量是在 fock
时直接继承,速度很快,某些时候速度可能会相差量级……
有需要时,可以在 Pool
初始化前定义全局变量,之后可以在各子进程中均可访问。注意,在子进程中对全局变量的写入对主进程是没有影响的。
共享可写内存
在 Process 或 Pool 实例化时,使用了 os.fork(),当前工作内存会被复制到新的进程中,所以当前的全局变量在子进程中都可以访问。 fork
有个特性,叫 copy-on-write
,就是说其实是你修改了数据时,才真正进行内存复制(man fork)。不过考虑到在 Python 中对象的引用数总是在变化的,所以恐怕多数时候还是得复制的(参考这里)。
注意 fork 以后,各进程就是独立的了,某进程中的修改不会反映到另外的进程中去。如果想在不同进程中共享数据,要用队列传递数据,或者使用包中提供的 Value
和 Array
类,见此。
当然,对数组有更简单的方法,强烈推荐使用 sharedmem 包(文档在此),它利用内存映射 memmap,使你在不同进程中访问同一块内存。使用感觉类似于 OpenMP,不过注意这里的数据操作没有原子性,如果有必要得手动加锁(并损失性能)。
其实 sharedmem 还提供了 MapReduce
类,有点类似 Pool
,只是功能更多更强,对 pickle 没有要求,是很不错的替代选择。
PS: 我还给这个包贡献过两个辅助函数 :)
PSS: 有意思的是,这个包的作者也是做天体物理和宇宙学的。难怪了,需求类似啊。
PSSS: 另一个类似的包shmarray
获得当前进程
def func(i):
p = multiprocessing.current_process()
print i, p.name, p._identity, p.pid, os.getpid(), os.getppid()
其结果类似于:
1 PoolWorker-2 (2,) 1595 1595 1593
p._identity 是个不错的识别标志:
主进程为
()
()
的第一个子进程为(1,)
,第二个子进程为(2,)
(1,)
的第一个子进程为(1,1)
异常
异常只在返回值的时候才会 raise,也就是调用 get() 的时候。如果你用 async 方法,又没有 get,错误有可能发生得悄无声息。
中断进程
按 ctrl-C 无法中断主程序,产生大量僵尸, get 的时候加上超时即可解决:
res = async_res.get(1e6)
或者直接
res = pool.map_async(...).get(1e6)
这个超时指定一个比任务预计完成时间长的数就好了。1e6 秒是 11 天,相当可观了。
这里 和这里 还列出了其他方法,其中上面的方法应该是最简单的。
回调函数 callback
apply_async
和 map_async
可选参数 callback
,将在主进程中调用它来处理返回结果。注意若 callback
耗时太长会阻塞主进程。
map 多参数
内建 map 支持多参数,而 multiprocessing.map 就不行了。快速解决方案有二:
Python3 中可以使用 multiprocessing.starmap
使用函数包装,例:
def func(x, y):
return x + y
def func_wrap(args):
return func(*args)
pool.map(func_wrap, zip(xlist, ylist))
我们可以用装饰器做得更漂亮:
def unpack_args(func):
from functools import wraps
@wraps(func)
def wrapper(args):
if isinstance(args, dict):
return func(**args)
else:
return func(*args)
return wrapper
使用方法:在目标函数定义前加装饰 @unpack_args
,见下例
@unpack_args
def func(x, y):
return x + y
np, xlist, ylist = 2, range(10), range(10)
pool = Pool(np)
res = pool.map(func, zip(xlist, ylist))
pool.close()
pool.join()
更多讨论见 SO。
留意 multiprocessing.dummy
multiprocessing.dummy
是线程版的 multiprocessing
,两者 API 完全一样。
在处理 IO 密集型任务时可以考虑使用此模块。
随机数
因为 multiprocessing
子进程会拷贝主进程的状态,使用 numpy.random.rand
等函数生成伪随机数时,不同进程中会生成相同的随机数,这就失去随机数的意义了。
解决方案是在执行 numpy.random.rand
前,先执行 numpy.random.seed()
,它会利用系统时间等方法刷新 seed
。这样各进程将拥有不同的随机种子。
遍历参数组合
使用 itertools.product
遍历参数组合。
例:函数 func(a, b)
想遍历 a_list = [a1, ..., an]
与 b_list = [b1, ..., bn]
中的所有 [ai, bj]
可能的组合
from itertools import product
args = product(a_list, b_list)
map(func, args)
itertools
里面还有不少好东西,比如 combinations
和 permutations
。
锁
有时想在并行中让部分代码单进程执行,比如写入文件,多进程同时修改同一文件可能损坏数据,因而只允许一个进程操作。下面是一个简单例子,只用一个全局锁。
from multiprocessing import Pool, Lock
def func(arg):
do_parallel_thing()
with lock:
do_serial_thing()
lock = Lock()
pool = Pool(nproc)
res = pool.map_async(func, args).get(time)
pool.close()
pool.join()
锁可以有非常复杂的玩法,不过这个例子足以应付多数需要。
其他参考材料
A good note on (saving) memory:里面关于 spwan 的机制值得一看(windows 和 python3 下的 linux 可用),可以通过
if __name__ == "__main__"
区分子进程,并避免 if 下面定义的对象被复制。
替代包
Pathos
值得留意。
练习
试预测下面代码中各处 print
打印的值。
import os
import multiprocessing
from multiprocessing import Pool
def func(i):
global a
p = multiprocessing.current_process()
print i, p.name, p._identity, p.pid, os.getpid(), os.getppid(), b, a
a += 10
a, b = 1, 1
pool = Pool(2)
a, b = 2, 2
print '--------'
pool.map(func, range(4))
pool = Pool(2)
print '--------'
pool.map(func, range(4))
def func(i, c=[]):
p = multiprocessing.current_process()
print i, p.name, p._identity, p.pid, os.getpid(), c
c.append(i)
print '--------'
pool.map(func, range(4))
pool = Pool(2)
print '--------'
pool.map(func, range(4))
仅有一条评论
添加新评论
- 上一篇: 上海坐标 - 城市定向挑战赛线路
- 下一篇: Homebrew 使用
[...]推荐阅读:http://luly.lamost.org/blog/python_multiprocessing.html[...]