diff --git a/amrbatch/main.py b/amrbatch/main.py index 075ae680cc19018448154bae8714bb49f09258aa..baf568e138b1bef27b0c29578742d2f53372f487 100644 --- a/amrbatch/main.py +++ b/amrbatch/main.py @@ -13,7 +13,9 @@ import subprocess import amrlib from rdflib import Graph import traceback -import logging.config +import logging +import multiprocessing_logging +import multiprocessing from amrlib.graph_processing.amr_plot import AMRPlot from filepath_manager import FilepathManager @@ -91,18 +93,17 @@ def __generate_sentence_file(filepath_manager, workdata_list): #============================================================================== def __generate_penman_amr_graph(filepath_manager, data): - """ AMR graph generation in penman format """ - + """ AMR graph generation in penman format """ output_filepath = data.get_penman_amr_graph_output_filepath() logger.debug(f"----- AMR Graph file (penman): {os.path.basename(output_filepath)}") with open(output_filepath, "w") as writing_file: # w = write writing_file.write(data.id_line_str) writing_file.write(data.graph) + return(output_filepath) def __generate_dot_amr_graph(filepath_manager, data): - """ AMR graph generation in dot and png format """ - + """ AMR graph generation in dot and png format """ try: # -- generating dot/png/svg files using AMRLib and GraphViz dot_filename = data.get_dot_amr_graph_output_filepath() @@ -115,7 +116,7 @@ def __generate_dot_amr_graph(filepath_manager, data): good_png_fn = data.get_png_amr_graph_output_filepath() logger.debug(f'----- AMR Graph file (png): {{os.path.basename(good_png_fn)}}') os.rename(render_fn, good_png_fn) - + returnValue = dot_filename format = 'svg' plot = AMRPlot(dot_filename, format) plot.build_from_graph(data.graph) @@ -124,37 +125,51 @@ def __generate_dot_amr_graph(filepath_manager, data): good_svg_fn = good_png_fn.replace('.png','.svg') logger.debug(f'----- AMR Graph file (svg): {{os.path.basename(good_svg_fn)}}') os.rename(render_fn, good_svg_fn) - - - except: - logger.warning('Exception when trying to plot') + except Exception as ex: + logger.warning('Exception when trying to plot: '+ex) traceback.print_exc() - - -def __convert_sentences_to_graphs(amr_model, workdata_list): - """ Converting text sentences to AMR graphs """ - + returnValue = 'Exception when trying to plot' + return(returnValue) + +# Function executed when a worker is created in the pool +def init_pool_worker(): + amr_model_path = "/home/daxid/hdd_data/jupyterlab_root/lib/amrModel/model_parse_xfm_bart_large-v0_1_0" + # declare scope of a new global variable + global stog + # store argument in the global variable for this process logger.info("-- Loading AMR model") stog = amrlib.load_stog_model(model_dir=amr_model) - logger.info("-- Converting sentences to AMR graphs") - wd_number = 0 - for data in workdata_list: - wd_number += 1 - stog_result = stog.parse_sents([data.sentence]) - logger.info(f'----- Sentence {wd_number} successfully processed') - logger.debug(stog_result) - data.graph = stog_result[0] +def __convert_sentence_to_graph_multiprocess_run(data): + print("in worker\n") + wd_number = 1 + stog_result = stog.parse_sents([data.sentence]) + logger.info(f'----- Sentence {wd_number} successfully processed') + logger.debug(stog_result) + data.graph = stog_result[0] + return(stog_result) - logger.info(f'----- Total processed graph number: {wd_number}') +def __convert_sentences_to_graphs(amr_model_path, workdata_list): + """ Converting text sentences to AMR graphs """ + # ----- (Multi-processing) Extraction Run + number_of_processes = min(multiprocessing.cpu_count()-1, len(workdata_list)) + global stog + with multiprocessing.Pool(2, initializer=init_pool_worker) as p: + logger.info("-- Converting sentences to AMR graphs") + print("pool created\n") + stog_result_list = p.map(__convert_sentence_to_graph_multiprocess_run, workdata_list) + logger.info(f'----- Total processed graph number: {len(stog_result_list)}') return workdata_list def __generate_amr_graph_files(filepath_manager, workdata_list): logger.info("-- Generating AMR graph files") - for data in workdata_list: - __generate_penman_amr_graph(filepath_manager, data) - __generate_dot_amr_graph(filepath_manager, data) + # ----- Prepare multiprocessing data + starmapIterable = [(data,filepath_manager) for data in workdata_list] + # ----- (Multi-processing) Extraction Run + with multiprocessing.Pool(multiprocessing.cpu_count()-1) as p: + penmanFilePathList = p.starmap(__generate_penman_amr_graph, starmapIterable) + dotFilePathList = p.starmap(__generate_dot_amr_graph, starmapIterable) @@ -237,6 +252,10 @@ def __analyze_line_set_to_produce_amr_graphs(line_set, data_reference, amr_model logger.info(f'-- library: amrlib') logger.info(f'-- model: {os.path.basename(amr_model_path)}') logger.debug(f' ({amr_model_path})') + + # ----- Multiprocessing Logging (must be exec before the pool is created) + multiprocessing_logging.install_mp_handler() + workdata_list = __convert_sentences_to_graphs(amr_model_path, workdata_list) __generate_amr_graph_files(filepath_manager, workdata_list)