diff --git a/amrbatch/main.py b/amrbatch/main.py
index 742a2e4cc979835a06280d0831108939bea861e2..a50a857507ba97a3cb34b92420175008423839c9 100644
--- a/amrbatch/main.py
+++ b/amrbatch/main.py
@@ -16,6 +16,7 @@ import traceback
 import logging
 import multiprocessing_logging
 import multiprocessing
+from multiprocessing import Manager
 
 from amrlib.graph_processing.amr_plot import AMRPlot
 from filepath_manager import FilepathManager
@@ -87,23 +88,28 @@ def __generate_sentence_file(filepath_manager, workdata_list):
             if not first:
                 writing_file.write("\n")
             writing_file.write(workdata.sentence)
             first = False
-
-
-
-
-
 #==============================================================================
 # Sentence Conversion to AMR
 #==============================================================================
 
+# Function executed when a worker is created in the pool
+def init_pool_worker():
+    # declare scope of a new global variable
+    global stog
+    amr_model_path = "/home/daxid/hdd_data/jupyterlab_root/lib/amrModel/model_parse_xfm_bart_large-v0_1_0"
+    # store argument in the global variable for this process
+    logger.info("-- Loading AMR model")
+    stog = amrlib.load_stog_model(model_dir=amr_model_path)
+
+
 def __run_conversion(arg_dict):
     data = arg_dict['data']
     amr_model_path = arg_dict['amr_model_path']
 
-    logger.info("-- Loading AMR model")
-    stog = amrlib.load_stog_model(model_dir=amr_model_path)
+#    logger.info("-- Loading AMR model")
+#    stog = amrlib.load_stog_model(model_dir=amr_model_path)
 
     logger.info("-- Converting sentences to AMR graphs")
     stog_result = stog.parse_sents([data.sentence])
 
@@ -116,6 +122,8 @@ def __run_conversion(arg_dict):
 
 def __convert_sentences_to_graphs(amr_model_path, input_data_list):
     """ Converting text sentences to AMR graphs """
+
+    global stog
 
     mapIterable = []
     for data in input_data_list:
@@ -123,7 +131,8 @@ def __convert_sentences_to_graphs(amr_model_path, input_data_list):
         mapIterable = mapIterable + [arg_dict]
 
     number_of_processes = min(round((multiprocessing.cpu_count()-1)/4), len(input_data_list))
-    with multiprocessing.Pool(number_of_processes) as p:
+    #with multiprocessing.Pool(number_of_processes) as p:
+    with multiprocessing.Pool(number_of_processes, initializer=init_pool_worker) as p:
         result_data_list = p.map(__run_conversion, mapIterable)
 
#    result_data_list = []