
I've been trying to parallelize some code that I wrote in Python. The actual work is embarrassingly parallel, but I don't have much experience with multiprocessing in Python.

The actual code I'm writing involves huge arrays, on the order of gigabytes, so pickling is infeasible (with that overhead, pure vectorized NumPy is faster). I've been trying to use the mp.shared_memory module, but the process crashes whenever I try to do anything with arrays in shared memory. I've reconstructed a minimal example below:

import numpy as np
import multiprocessing as mp

from multiprocessing import shared_memory


def worker_func(start, end):
    global test_array
    print(f"Worker processing {start}:{end}")
    print(f"Array shape: {test_array.shape}")
    print(f"Slicing [{start}:{end + 1}]...")
    slice_result = test_array[start:end + 1]
    print(f"Slice shape: {slice_result.shape}")
    return slice_result


def worker_init(shm_name, shape, dtype):
    global test_array
    print("Worker init...")
    shm = shared_memory.SharedMemory(name=shm_name)
    test_array = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
    print(f"Worker init done, shape: {test_array.shape}")


if __name__ == "__main__":
    shape = (10,)  # whatever; fails for both 1-D and higher-dimensional arrays
    data = np.random.randn(*shape) + 1j * np.random.randn(*shape)
    data = data.astype(np.complex128)

    # Put in shared memory
    shm = shared_memory.SharedMemory(create=True, size=data.nbytes)
    shared_arr = np.ndarray(shape, dtype=data.dtype, buffer=shm.buf)
    print("Copying data to shared memory...")
    shared_arr[:] = data[:]
    print("Copy done")
    end = np.size(data) - 1

    # Multiprocessing with a single worker process for testing
    with mp.Pool(processes=1, initializer=worker_init, initargs=(shm.name, shape, data.dtype)) as pool:
        print("Workers ready, submitting work...")
        result = pool.starmap(worker_func, [(0, end)])  # Your batch_size
        print(f"Got result, shape: {result.shape}")

    shm.close()
    shm.unlink()
    print("Done")

The output is:

Copy done
Workers ready, submitting work...
Worker init...
Worker init done, shape: (10,)
Worker processing 0:9
Array shape: (10,)
Slicing [0:10]...
Slice shape: (10,)
Worker init...
Worker init done, shape: (10,)

So it seems to run the init function twice, once after the starmap call has already happened. It also never reaches the "Done" or the "Got result" statements after the starmap call. When this sort of thing happens, it usually means I'm missing something obvious. What am I misunderstanding here?

If it's relevant, I'm on an M2 MacBook Air, but the issue reproduces on the supercomputer that I'm developing this program for.

  • Whilst not relevant to the issue at hand, it's interesting to note that the initializer is called twice. This contradicts the documentation, which states: "If initializer is not None then each worker process will call initializer(*initargs) when it starts." In this case it should only be called once. Commented Oct 11 at 8:23
  • I’m not sure about the details, but the cause is likely that shm (in the worker_init) gets garbage collected while it's still being used via test_array. How about storing shm as a global variable along with test_array? Commented Oct 11 at 9:12
  • worker_init will be called for each pool process created. If you specify Pool(processes=1, ...) it will only be called once. Are you sure you didn't have at some point in your code processes=2? Commented Oct 11 at 11:38
  • @Booboo I have created a small example that shows that when processes==1, the initializer gets called twice. Having said that, I don't think it has anything to do with the real problem. Commented Oct 11 at 12:34
  • @Ramrab You are correct (see my answer), but I suspect it is related to the problem since the initializer function should only be called at most N times where N is the size of the pool. Commented Oct 11 at 12:42

2 Answers


TL;DR

The reason why your worker_init function is being called more times than the number of processes you have in your pool (which is 1) is that your pool process is terminating abnormally, and the pool then needs to re-create the pool process, which calls the initializer again. The reason why the process is terminating is that your initializer does not retain a reference to the shared memory in global storage. Adding a global shm statement to your worker_init function (and assigning the handle to it) resolves that issue. The other problem is that your call to starmap returns a list, which in your case contains a single NumPy array, yet you are calling shape on this list rather than on an element of the list.
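
For example, here is a minimal sketch of those two fixes applied to the question's code (only the changed parts are shown):

import numpy as np
from multiprocessing import shared_memory


def worker_init(shm_name, shape, dtype):
    # Keep BOTH the array and the SharedMemory handle in globals.
    # If shm stays local, it is garbage collected when this function
    # returns, invalidating the buffer that test_array points into.
    global test_array, shm
    shm = shared_memory.SharedMemory(name=shm_name)
    test_array = np.ndarray(shape, dtype=dtype, buffer=shm.buf)

# ... and in the main block, index the list that starmap returns:
#     result = pool.starmap(worker_func, [(0, end)])
#     print(f"Got result, shape: {result[0].shape}")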

More Info

It may well be that when a shared memory handle is garbage collected, close is automatically called on the handle. If so, that would explain why your pool process might be terminating abnormally, but I don't think that is the case. Either way, to ensure shared memory is released properly on all platforms and to prevent memory leaks, we should ensure that close is issued on the handle when the process terminates:

  1. Wrap the shared memory in a special wrapper class, SharedMemoryWrapper, whose __del__ method closes the shared memory handle when the wrapper instance is garbage collected; each pool process should store its wrapper instance so that it correctly maintains its own reference.
  2. You are using a with Pool(...) as pool: context manager. When this block exits, an implicit call to pool.terminate() is made to terminate each process. This can prevent the shared memory wrapper from being properly garbage collected, and thus close would never be called on the shared memory handle. Instead, "terminate" the pool gracefully by calling pool.close() followed by pool.join(), allowing each pool process to end normally.

When I make these changes your code runs correctly and terminates as expected.

import multiprocessing as mp
from multiprocessing import shared_memory

import numpy as np

class SharedMemoryWrapper:
    def __init__(self, shm_name, shape, dtype):
        print('accessing shared memory')
        self._shared_mem = shared_memory.SharedMemory(name=shm_name)
        self._arr = np.ndarray(shape, dtype=dtype, buffer=self._shared_mem.buf)

    @property
    def arr(self):
        return self._arr

    def __del__(self):
        print('shared memory is being closed')
        self._shared_mem.close()

def worker_func(start, end):
    global test_array
    print(f"Worker processing {start}:{end}")
    print(f"Array shape: {test_array.shape}")
    print(f"Slicing [{start}:{end + 1}]...")
    slice_result = test_array[start:end + 1]
    print(f"Slice shape: {slice_result.shape}")
    return slice_result

def worker_init(shm_name, shape, dtype):
    global test_array
    global shared_memory_wrapper

    print("Worker init...")
    shared_memory_wrapper = SharedMemoryWrapper(shm_name, shape, dtype)
    test_array = shared_memory_wrapper.arr
    print(f"Worker init done, shape: {test_array.shape}")

if __name__ == "__main__":
    shape = (10,)  # whatever; fails for both 1-D and higher-dimensional arrays
    data = np.random.randn(*shape) + 1j * np.random.randn(*shape)
    data = data.astype(np.complex128)

    # Put in shared memory
    shm = shared_memory.SharedMemory(create=True, size=data.nbytes)
    shared_arr = np.ndarray(shape, dtype=data.dtype, buffer=shm.buf)
    print("Copying data to shared memory...")
    shared_arr[:] = data[:]
    print("Copy done")
    end = np.size(data) - 1

    # Multiprocessing with a single worker process for testing.
    # Create the pool and let the pool processes terminate gracefully
    # by calling pool.close() followed by pool.join(), rather than relying on the
    # implicit call to pool.terminate() made when a with mp.Pool(...) as pool: block exits:
    pool = mp.Pool(processes=1, initializer=worker_init, initargs=(shm.name, shape, data.dtype))
    print("Workers ready, submitting work...")
    result = pool.starmap(worker_func, [(0, end)])  # Your batch_size
    print(f"Got result, shape: {result[0].shape}")
    # Terminate pool gracefully to ensure shared memory is closed by each process:
    pool.close()
    pool.join()

    shm.close()
    shm.unlink()
    print("Done")

Prints:

Copying data to shared memory...
Copy done
Workers ready, submitting work...
Worker init...
accessing shared memory
Worker init done, shape: (10,)
Worker processing 0:9
Array shape: (10,)
Slicing [0:10]...
Slice shape: (10,)
Got result, shape: (10,)
shared memory is being closed
Done

Update (Improved Solution)

I have since looked at the code of the multiprocessing.shared_memory module for Python 3.12, and I see that the SharedMemory class has a __del__ method that does issue a close. So when a reference to shared memory is garbage collected, the close is done automatically. I don't know in which release this logic was introduced; it might have been there since day one. Consequently, if the version of Python you are running has this logic (and my guess is that it does), then:

  1. That definitely explains why your initial code did not work when your initializer was not saving a reference to shared memory in a global. It is still not clear to me why the actual pool process terminates, which does not normally occur just because a task it is running raises an exception.
  2. Saving a reference to shared memory as a global in your initializer is necessary and might be sufficient to resolve your problem; there would be no need for the SharedMemoryWrapper class I introduced in the More Info section. But can we always guarantee that shared memory's __del__ method will be invoked? It will not be if you use your pool instance as a context manager, as you were doing, because of the implicit call to pool.terminate(). Therefore, the best solution to ensure that the handle is always closed is to register a function to be called whenever the pool process terminates:
import atexit

...

def worker_init(shm_name, shape, dtype):
    global test_array, shm
    print("Worker init...")
    shm = shared_memory.SharedMemory(name=shm_name)

    # Ensure the handle is closed:
    atexit.register(shm.close)

    test_array = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
    print(f"Worker init done, shape: {test_array.shape}")

2 Comments

Resolved the issue both in the example and in the main code. Thank you!
Please look at the added Update section I posted to my answer.

On Windows, multiprocessing.Pool uses spawn, so each worker runs the initializer separately. Returning slices from workers can hang because large arrays are pickled.

For GB-sized arrays, use multiprocessing.shared_memory and modify arrays in-place to avoid pickling overhead. This approach is efficient and works reliably across processes.
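
As a rough sketch of that pattern (the names and the doubling operation here are illustrative, not taken from the question), each worker writes its result back into the shared array instead of returning it, so nothing large is ever pickled:

import multiprocessing as mp
from multiprocessing import shared_memory

import numpy as np


def _init(shm_name, shape, dtype):
    # Keep the handle alive in a global for the worker's lifetime.
    global _shm, _arr
    _shm = shared_memory.SharedMemory(name=shm_name)
    _arr = np.ndarray(shape, dtype=dtype, buffer=_shm.buf)


def _scale_chunk(start, end):
    # Modify the shared array in place; only the small (start, end)
    # tuple and a None return value ever cross process boundaries.
    _arr[start:end] *= 2.0


if __name__ == "__main__":
    n = 1_000_000
    data = np.arange(n, dtype=np.float64)

    shm = shared_memory.SharedMemory(create=True, size=data.nbytes)
    arr = np.ndarray(data.shape, dtype=data.dtype, buffer=shm.buf)
    arr[:] = data

    chunks = [(i, i + n // 4) for i in range(0, n, n // 4)]
    pool = mp.Pool(4, initializer=_init,
                   initargs=(shm.name, data.shape, data.dtype))
    pool.starmap(_scale_chunk, chunks)
    pool.close()  # let workers exit normally (see the accepted answer)
    pool.join()

    assert np.allclose(arr, data * 2.0)
    shm.close()
    shm.unlink()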

2 Comments

The issue isn't that "init" is being called multiple times; it's that the worker is hanging. And the array in this example isn't large; it's only 10 complex numbers. There shouldn't be any issue with the array being too big.
As it’s currently written, your answer is unclear. Please edit to add additional details that will help others understand how this addresses the question asked. You can find more information on how to write good answers in the help center.

