TL;DR
The reason your worker_init function is being called more times than the number of processes in your pool (which is 1) is that the pool process is terminating abnormally, and the pool then has to re-create it, which calls the initializer again. The process is terminating because your initializer does not retain a reference to the shared memory in global storage; adding a global shm statement to your worker_init function (and assigning the SharedMemory instance to it) would resolve that issue. The other problem is that your call to starmap returns a list, which in your case contains a single np array element, yet you are calling shape on this list rather than on an element of the list.
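A minimal sketch of just those two fixes, using the names from your code (the complete, runnable version is shown further below):

def worker_init(shm_name, shape, dtype):
    # Keep the SharedMemory instance itself in a global so it is not
    # garbage collected (and implicitly closed) when worker_init returns:
    global test_array, shm
    shm = shared_memory.SharedMemory(name=shm_name)
    test_array = np.ndarray(shape, dtype=dtype, buffer=shm.buf)

# In the main process, starmap returns a list of results:
result = pool.starmap(worker_func, [(0, end)])
print(result[0].shape)  # index the list before calling shape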
More Info
It may well be that when a shared memory handle is garbage collected, a close is automatically called on the handle. If so, that would explain why your pool process might be terminating abnormally, but I don't think that is the case. In any event, to ensure the shared memory is released properly on all platforms and to prevent memory leaks, we should make sure that a close is issued on the handle when the process terminates:
- Wrap the shared memory in a special wrapper class, SharedMemoryWrapper, which contains a __del__ method that will close the shared memory handle when the wrapper instance is garbage collected. The wrapper instance should be stored in a global so that each pool process correctly maintains a reference to its own wrapper instance.
- You are using a with Pool(...) as pool: context manager. When this block exits, an implicit call to pool.terminate() is made to terminate each pool process. This could prevent the shared memory wrapper from being properly garbage collected, and thus close would never be called on the shared memory handle. Instead, gracefully "terminate" the pool by calling pool.close() followed by pool.join(), allowing each pool process to end normally.
When I make these changes your code runs correctly and terminates as expected.
import multiprocessing as mp
from multiprocessing import shared_memory

import numpy as np


class SharedMemoryWrapper:
    def __init__(self, shm_name, shape, dtype):
        print('accessing shared memory')
        self._shared_mem = shared_memory.SharedMemory(name=shm_name)
        self._arr = np.ndarray(shape, dtype=dtype, buffer=self._shared_mem.buf)

    @property
    def arr(self):
        return self._arr

    def __del__(self):
        print('shared memory is being closed')
        self._shared_mem.close()


def worker_func(start, end):
    global test_array

    print(f"Worker processing {start}:{end}")
    print(f"Array shape: {test_array.shape}")
    print(f"Slicing [{start}:{end + 1}]...")
    slice_result = test_array[start:end + 1]
    print(f"Slice shape: {slice_result.shape}")
    return slice_result


def worker_init(shm_name, shape, dtype):
    global test_array
    global shared_memory_wrapper

    print("Worker init...")
    shared_memory_wrapper = SharedMemoryWrapper(shm_name, shape, dtype)
    test_array = shared_memory_wrapper.arr
    print(f"Worker init done, shape: {test_array.shape}")


if __name__ == "__main__":
    shape = (10,)  # whatever, fails for both 1 dim and high dimensional arrays
    data = np.random.randn(*shape) + 1j * np.random.randn(*shape)
    data = data.astype(np.complex128)

    # Put in shared memory
    shm = shared_memory.SharedMemory(create=True, size=data.nbytes)
    shared_arr = np.ndarray(shape, dtype=data.dtype, buffer=shm.buf)
    print("Copying data to shared memory...")
    shared_arr[:] = data[:]
    print("Copy done")

    end = np.size(data) - 1

    # Multiprocessing with a single process for test reasons.
    # Create the pool and let the pool processes terminate gracefully
    # by calling pool.close() followed by pool.join() rather than relying on the
    # implicit call to pool.terminate() made when a with mp.Pool(...) as pool: block exits:
    pool = mp.Pool(processes=1, initializer=worker_init, initargs=(shm.name, shape, data.dtype))
    print("Workers ready, submitting work...")
    result = pool.starmap(worker_func, [(0, end)])  # Your batch_size
    print(f"Got result, shape: {result[0].shape}")

    # Terminate the pool gracefully to ensure shared memory is closed by each process:
    pool.close()
    pool.join()

    shm.close()
    shm.unlink()
    print("Done")
Prints:
Copying data to shared memory...
Copy done
Workers ready, submitting work...
Worker init...
accessing shared memory
Worker init done, shape: (10,)
Worker processing 0:9
Array shape: (10,)
Slicing [0:10]...
Slice shape: (10,)
Got result, shape: (10,)
shared memory is being closed
Done
Update (Improved Solution)
I have since looked at the source for the multiprocessing.shared_memory module in Python 3.12 and I see that the SharedMemory class has a __del__ method that does issue a close. So when a reference to shared memory is garbage collected, the close will be done automatically. I don't know in what release this logic was introduced; it might have been there since day one. Consequently, if the version of Python you are running has this logic (and my guess is that it does), then:
- That definitely explains why your initial code did not work when your initializer was not saving a reference to shared memory in a global. It is still not clear to me why the actual pool process terminates, which does not normally occur just because a task it is running raises an exception.
- Saving a reference to the shared memory as a global in your initializer is necessary and might be sufficient to resolve your problem; there would be no need to use the SharedMemoryWrapper class I introduced in the More Info section. But can we always guarantee that the shared memory's __del__ method will be invoked? It will not be if you use your pool instance as a context manager, as you were doing, because of the implicit call to pool.terminate() that is made. Therefore, the best solution to ensure that the handle is always closed is to register a function to be called whenever the pool process terminates:
import atexit

...

def worker_init(shm_name, shape, dtype):
    global test_array, shm

    print("Worker init...")
    shm = shared_memory.SharedMemory(name=shm_name)
    # Ensure the handle is closed:
    atexit.register(shm.close)
    test_array = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
    print(f"Worker init done, shape: {test_array.shape}")
shm (in the worker_init) gets garbage collected while it's still being used via test_array. How about storing shm as a global variable along with test_array? worker_init will be called for each pool process created; if you specify Pool(processes=1, ...) it will only be called once. Are you sure you didn't at some point have processes=2 in your code?