From 2224211a758f4965d0f7f825baabc97acb2c030e Mon Sep 17 00:00:00 2001 From: daxid <david.rouquet@tetras-libre.fr> Date: Wed, 28 Jun 2023 11:32:32 +0200 Subject: [PATCH] Multiprocessing run with initializer --- amrbatch/main.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/amrbatch/main.py b/amrbatch/main.py index 742a2e4c..a50a8575 100644 --- a/amrbatch/main.py +++ b/amrbatch/main.py @@ -16,6 +16,7 @@ import traceback import logging import multiprocessing_logging import multiprocessing +from multiprocessing import Manager from amrlib.graph_processing.amr_plot import AMRPlot from filepath_manager import FilepathManager @@ -87,23 +88,28 @@ def __generate_sentence_file(filepath_manager, workdata_list): if not first: writing_file.write("\n") writing_file.write(workdata.sentence) first = False - - - - - #============================================================================== # Sentence Conversion to AMR #============================================================================== +# Function executed when a worker is created in the pool +def init_pool_worker(): + # declare scope of a new global variable + global stog + amr_model_path = "/home/daxid/hdd_data/jupyterlab_root/lib/amrModel/model_parse_xfm_bart_large-v0_1_0" + # store argument in the global variable for this process + logger.info("-- Loading AMR model") + stog = amrlib.load_stog_model(model_dir=amr_model_path) + + def __run_conversion(arg_dict): data = arg_dict['data'] amr_model_path = arg_dict['amr_model_path'] - logger.info("-- Loading AMR model") - stog = amrlib.load_stog_model(model_dir=amr_model_path) +# logger.info("-- Loading AMR model") +# stog = amrlib.load_stog_model(model_dir=amr_model_path) logger.info("-- Converting sentences to AMR graphs") stog_result = stog.parse_sents([data.sentence]) @@ -116,6 +122,8 @@ def __run_conversion(arg_dict): def __convert_sentences_to_graphs(amr_model_path, input_data_list): """ Converting text sentences to AMR graphs """ + + global stog mapIterable = [] for data in input_data_list: @@ -123,7 +131,8 @@ def __convert_sentences_to_graphs(amr_model_path, input_data_list): mapIterable = mapIterable + [arg_dict] number_of_processes = min(round((multiprocessing.cpu_count()-1)/4), len(input_data_list)) - with multiprocessing.Pool(number_of_processes) as p: + #with multiprocessing.Pool(number_of_processes) as p: + with multiprocessing.Pool(number_of_processes, initializer=init_pool_worker) as p: result_data_list = p.map(__run_conversion, mapIterable) # result_data_list = [] -- GitLab