Python MultiProcessing 使用心得

由于 GIL 的缘故,Python 下计算密集型并行一般推荐使用进程并行,如 multiprocessing。但是 multiprocessing 的使用中实在是有很多 tricky 的地方,不小心就会掉坑里。说到底,主要还是因为对它的具体实现机制不够了解。以下是个人使用中总结的一些心得,主要关注进程池和共享内存。

请注意,本文为 linux/mac 下的使用经验,windows 没有 fork 机制,部分地方可能不同。另外,本文假定读者已经知道什么是 Pool,map,apply

注意程序组织

要并行调用的函数和参数都应在 Pool 实例化之前定义好,并且要注意可序列化(pickle)性,避免使用 lambda 匿名函数。在实例化 Pool 对象时,fork 已经完成。之后主进程中变量更新不会影响子进程。
进程池一定要在被调函数定义之后实例化

什么可以被 pickle

参见what is pickable

  • None,True,False,内建数字类型,字符串

  • picklable 对象组成的 tuple、list、set、dict

  • 在模块顶层中定义的函数、内建函数、类

  • __dict____getstate__() 是可 pickle 的类实例

  • 特别地,numpy ndarray

注意 lambda 匿名函数不可被 pickle,除非使用 dill 之类的包,见
pickle 并不是什么特别神秘的东西,心态要放松,别怕。

遇到不能 pickle 的东西

  • 用可 pickle 的对象包装一下。
    比如,实例方法不可pickle,那我们可以传递实例,再调用方法

def work(foo):
    foo.work()
pool.apply_async(work,args=(foo,))
  • 如果是实例,可修改其 __getstate__ 方法,过滤掉不可 pickle 的东西。

Decoration

装饰器是非常实用的小技巧,可惜装饰器不能被 pickle。原因是 pickle 通过函数名在模块中查找函数,decorated function 的 __name__ 属性是它定义时的名字,一般不是它现在的变量名。幸运的是,很容易用 functools 绕过这个问题,解决方案见此
可以在下面 map 多参数 一节找到例子。

共享只读变量

原则上说,一般尽量避免使用全局变量,还是通过 map 或者 apply 将参数传给 Pool 对象比较规范。不过实践中,为了速度,有时只能做出一点让步。参数是通过 pickle 传递的,效率较低;而全局变量是在 fock 时直接继承,速度很快,某些时候速度可能会相差量级……
有需要时,可以在 Pool 初始化前定义全局变量,之后可以在各子进程中均可访问。注意,在子进程中对全局变量的写入对主进程是没有影响的。

共享可写内存

在 Process 或 Pool 实例化时,使用了 os.fork(),当前工作内存会被复制到新的进程中,所以当前的全局变量在子进程中都可以访问。 fork 有个特性,叫 copy-on-write,就是说其实是你修改了数据时,才真正进行内存复制(man fork)。不过考虑到在 Python 中对象的引用数总是在变化的,所以恐怕多数时候还是得复制的(参考这里)。
注意 fork 以后,各进程就是独立的了,某进程中的修改不会反映到另外的进程中去。如果想在不同进程中共享数据,要用队列传递数据,或者使用包中提供的 ValueArray 类,见此
当然,对数组有更简单的方法,强烈推荐使用 sharedmem 包(文档在此),它利用内存映射 memmap,使你在不同进程中访问同一块内存。使用感觉类似于 OpenMP,不过注意这里的数据操作没有原子性,如果有必要得手动加锁(并损失性能)。
其实 sharedmem 还提供了 MapReduce 类,有点类似 Pool,只是功能更多更强,对 pickle 没有要求,是很不错的替代选择。
PS: 我还给这个包贡献过两个辅助函数 :)
PSS: 有意思的是,这个包的作者也是做天体物理和宇宙学的。难怪了,需求类似啊。
PSSS: 另一个类似的包shmarray

获得当前进程

def func(i):
    p = multiprocessing.current_process()
    print i, p.name, p._identity, p.pid, os.getpid(), os.getppid()

其结果类似于:

1 PoolWorker-2 (2,) 1595 1595 1593

p._identity 是个不错的识别标志:

  • 主进程为 ()

  • () 的第一个子进程为 (1,),第二个子进程为 (2,)

  • (1,) 的第一个子进程为 (1,1)

异常

异常只在返回值的时候才会 raise,也就是调用 get() 的时候。如果你用 async 方法,又没有 get,错误有可能发生得悄无声息。

中断进程

按 ctrl-C 无法中断主程序,产生大量僵尸, get 的时候加上超时即可解决:

res = async_res.get(1e6)

或者直接

res = pool.map_async(...).get(1e6)

这个超时指定一个比任务预计完成时间长的数就好了。1e6 秒是 11 天,相当可观了。
这里这里 还列出了其他方法,其中上面的方法应该是最简单的。

回调函数 callback

apply_asyncmap_async 可选参数 callback,将在主进程中调用它来处理返回结果。注意若 callback 耗时太长会阻塞主进程。

map 多参数

内建 map 支持多参数,而 multiprocessing.map 就不行了。快速解决方案有二:

  • Python3 中可以使用 multiprocessing.starmap

  • 使用函数包装,例:

def func(x, y):
    return x + y
def func_wrap(args):
    return func(*args)
pool.map(func_wrap, zip(xlist, ylist))
  • 我们可以用装饰器做得更漂亮:

def unpack_args(func):
    from functools import wraps
    @wraps(func)
    def wrapper(args):
        if isinstance(args, dict):
            return func(**args)
        else:
            return func(*args)
    return wrapper

使用方法:在目标函数定义前加装饰 @unpack_args ,见下例

@unpack_args
def func(x, y):
    return x + y

np, xlist, ylist = 2, range(10), range(10)
pool = Pool(np)
res = pool.map(func, zip(xlist, ylist))
pool.close()
pool.join()

更多讨论见 SO

留意 multiprocessing.dummy

multiprocessing.dummy 是线程版的 multiprocessing ,两者 API 完全一样。
在处理 IO 密集型任务时可以考虑使用此模块。

随机数

因为 multiprocessing 子进程会拷贝主进程的状态,使用 numpy.random.rand 等函数生成伪随机数时,不同进程中会生成相同的随机数,这就失去随机数的意义了。
解决方案是在执行 numpy.random.rand 前,先执行 numpy.random.seed(),它会利用系统时间等方法刷新 seed。这样各进程将拥有不同的随机种子。

遍历参数组合

使用 itertools.product 遍历参数组合。
例:函数 func(a, b) 想遍历 a_list = [a1, ..., an]b_list = [b1, ..., bn] 中的所有 [ai, bj] 可能的组合

from itertools import product
args = product(a_list, b_list)
map(func, args)

itertools 里面还有不少好东西,比如 combinationspermutations

有时想在并行中让部分代码单进程执行,比如写入文件,多进程同时修改同一文件可能损坏数据,因而只允许一个进程操作。下面是一个简单例子,只用一个全局锁。

from multiprocessing import Pool, Lock
def func(arg):
    do_parallel_thing()
    with lock:
        do_serial_thing()

lock = Lock()
pool = Pool(nproc)
res = pool.map_async(func, args).get(time)
pool.close()
pool.join()

锁可以有非常复杂的玩法,不过这个例子足以应付多数需要。

其他参考材料

替代包

练习

试预测下面代码中各处 print 打印的值。

import os
import multiprocessing
from multiprocessing import Pool

def func(i):
    global a
    p = multiprocessing.current_process()
    print i, p.name, p._identity, p.pid, os.getpid(), os.getppid(),  b,  a
    a += 10

a, b = 1, 1
pool = Pool(2)

a, b = 2, 2
print '--------'
pool.map(func, range(4))

pool = Pool(2)
print '--------'
pool.map(func, range(4))

def func(i, c=[]):
    p = multiprocessing.current_process()
    print i, p.name, p._identity, p.pid, os.getpid(), c
    c.append(i)

print '--------'
pool.map(func,  range(4))

pool = Pool(2)
print '--------'
pool.map(func,   range(4))

标签: python, parallel

赞 (42)