#!/usr/bin/python3.10
# -*-coding:Utf-8 -*

#==============================================================================
# C.M. Tool: PropBank Frame Analyzer
#------------------------------------------------------------------------------
# Module to analyze PropBank frames
#==============================================================================

#==============================================================================
# Importing required modules
#==============================================================================

import sys
import glob
import re

from bs4 import BeautifulSoup


#==============================================================================
# Parameters
#==============================================================================

# Input/Output Directories
INPUT_DIR = "../inputData/"
OUTPUT_DIR = "../outputData/"

# Data
# Directory searched for the frame file of a lemma ("<lemma>.xml")
PROPBANK_FRAMES_DIR = "../propbankFrames/"
# Width of the zero-padded roleset number used in a roleset id
PBF_DIGITS = 2
# Accepted spellings of an AMR core role (":ARG0", "ARG0" or "0").
# Raw strings are required: "\d" in a plain literal is an invalid escape.
AMR_CORE_ROLE_FORM = [r':ARG\d$', r'ARG\d$', r'\d$']


#==============================================================================
# Functions to analyze and adapt the target description
#==============================================================================

def itemize_amr_predicate(amr_predicate):
    """ Split an AMR predicate into its lemma and its roleset number

    The roleset number is the integer suffix that follows the last '-'
    (example: 'include-01' gives ('include', 1)). When the predicate has
    no such suffix, the whole predicate is the lemma and the roleset
    number defaults to 1.

    Only the last '-' is used as separator, so that a hyphenated lemma is
    kept whole (example: 'make-up-07' gives ('make-up', 7)) instead of
    raising ValueError while converting its second word to an integer.

    Returns a tuple (lemma, roleset_number).
    """
    head, separator, suffix = amr_predicate.rpartition('-')
    if separator:
        try:
            return head, int(suffix)
        except ValueError:
            # The text after the last '-' is not a number: it belongs to
            # the lemma, and the default roleset number applies
            pass
    return amr_predicate, 1


def get_lemma_from_amr_predicate(amr_predicate):
    """ Return the lemma part of an AMR predicate
    """
    return itemize_amr_predicate(amr_predicate)[0]
    

def get_role_ref_from_amr_predicate(amr_predicate):
    """ Return the roleset number of an AMR predicate as a padded string

    The number is left-padded with '0' up to PBF_DIGITS characters.
    """
    number = itemize_amr_predicate(amr_predicate)[1]
    return str(number).rjust(PBF_DIGITS, "0")
    
    
def get_roleset_id_from_amr_predicate(amr_predicate):
    """ Build the roleset id of an AMR predicate

    The id is the lemma and the padded roleset reference joined by a '.'.
    """
    return '.'.join((
        get_lemma_from_amr_predicate(amr_predicate),
        get_role_ref_from_amr_predicate(amr_predicate),
    ))


def get_number_from_amr_role(amr_role):
    """ Return the number of an AMR core role, or -1 for any other role

    A role is recognized when it matches, as a whole, one of the patterns
    of AMR_CORE_ROLE_FORM; its number is then its last character.
    """
    # fullmatch rather than match: '$' also matches just before a trailing
    # newline, which would let such a role through and then make the
    # conversion of its last character (the newline) fail
    for role_format in AMR_CORE_ROLE_FORM:
        if re.fullmatch(role_format, amr_role):
            return int(amr_role[-1])
    return -1


#==============================================================================
# Functions to find the XML description corresponding to a roleset
#==============================================================================

def find_frame_of_lemma(lemma):
    """ Find the Frame XML data corresponding to a given lemma

    Looks for the file '<lemma>.xml' in PROPBANK_FRAMES_DIR and parses the
    first match with BeautifulSoup (using the 'xml' parser).

    Returns a tuple (is_found, frame_filepath, frame_data). When no file
    matches, the tuple is (False, '', None).
    """

    target_file = PROPBANK_FRAMES_DIR + lemma + '.xml'
    matching_files = glob.glob(target_file, recursive=True)

    if not matching_files:
        return False, '', None

    frame_filepath = matching_files[0]

    # Read raw bytes and let the parser work out the encoding from the
    # document itself, rather than decoding with the platform's default
    # text encoding (which may not be the encoding of the file)
    with open(frame_filepath, 'rb') as f:
        frame_data = BeautifulSoup(f.read(), 'xml')

    return True, frame_filepath, frame_data


#==============================================================================
# Functions to analyze a frame data
#==============================================================================

def find_roleset_in_frame(frame_data, lemma, roleset_id):
    """ Find the roleset corresponding to a lemma and an id in a frame data

    Looks up the 'predicate' element whose 'lemma' attribute is the given
    lemma, then, inside it, the 'roleset' element whose 'id' attribute is
    the given roleset id.

    Returns a tuple (is_found, roleset_data). The tuple is (False, None)
    when the frame data is None, when the predicate or the roleset is
    missing, or when the lookup raises an error.
    """

    if frame_data is None:
        return False, None

    try:
        lemma_data = frame_data.find('predicate', {'lemma': lemma})
        if lemma_data is None:
            return False, None
        roleset_data = lemma_data.find('roleset', {'id': roleset_id})

    except Exception:
        # Best-effort lookup: a failing search is reported as "not found".
        # Exception (not a bare except) so that KeyboardInterrupt and
        # SystemExit still propagate.
        return False, None

    return roleset_data is not None, roleset_data


def find_role_in_roleset(roleset_data, role_number):
    """ Find the role corresponding to a given number in a roleset data

    Looks up the 'role' element whose 'n' attribute is the given number.

    Returns a tuple (is_found, role_data). The tuple is (False, None) when
    the roleset data is None, when no such role exists, or when the lookup
    raises an error.
    """

    if roleset_data is None:
        return False, None

    try:
        role_data = roleset_data.find('role', {'n': role_number})

    except Exception:
        # Best-effort lookup: a failing search is reported as "not found".
        # Exception (not a bare except) so that KeyboardInterrupt and
        # SystemExit still propagate.
        return False, None

    return role_data is not None, role_data


#==============================================================================
# Main Function(s)
#==============================================================================

def find_pb_role(amr_predicate, amr_role):
    """
    Find the PropBank role in PropBank frame corresponding to a given AMR
    predicate and a given AMR role.

    Parameters
    ----------
    amr_predicate : STRING
        AMR predicate (example: 'include-01').
    amr_role : STRING
        AMR core role (example: ':ARG0').

    Returns
    -------
    PropBank role (example: 'PAG'), read from the 'f' attribute of the
    matching 'role' element. None when the AMR role is not a core role, or
    when the frame file, the roleset or the role cannot be found.

    """

    # -- Analyze and adapt the target description
    lemma = get_lemma_from_amr_predicate(amr_predicate)
    roleset_id = get_roleset_id_from_amr_predicate(amr_predicate)
    role_number = get_number_from_amr_role(amr_role)

    # -- A negative number means that the AMR role is not a core role
    if role_number < 0:
        return None

    # -- Find the Frame XML data corresponding to a given lemma
    frame_found, _, frame_data = find_frame_of_lemma(lemma)
    if not frame_found:
        return None

    # -- Analyze frame data to find the target role
    rs_found, rs_data = find_roleset_in_frame(frame_data, lemma, roleset_id)
    if not rs_found:
        return None

    # The role is looked up by its own number. The count of 'role' elements
    # is not a valid upper bound for it: a roleset need not number its
    # roles contiguously from 0 (a roleset without role 0 would otherwise
    # hide its last role).
    r_found, role_data = find_role_in_roleset(rs_data, role_number)
    if not r_found:
        return None

    return role_data.get('f')
    

#==============================================================================
# *** Dev Test ***
#==============================================================================

def dev_analyze(amr_predicate, amr_role):
    """ Print a step-by-step trace of the PropBank role lookup

    Development helper: prints the lemma, roleset id and role number
    derived from the given AMR predicate and AMR role, the frame file and
    roleset found for them, the roles of that roleset, and finally the
    result of find_pb_role for the same arguments.
    """

    print("\n" + "[CMT] PropBank Frame Analyzer")

    # -- Analyze and adapt the target description
    print("-- Analyzing given data to specify the targetted data")
    print("----- given data: " + amr_predicate + ', ' + amr_role)
    lemma = get_lemma_from_amr_predicate(amr_predicate)
    print("----- lemma: " + lemma)
    roleset_id = get_roleset_id_from_amr_predicate(amr_predicate)
    print("----- roleset id: " + roleset_id)
    role_number = get_number_from_amr_role(amr_role)
    print("----- role number: " + str(role_number))

    # -- Find the Frame XML data corresponding to a given lemma
    print("-- Finding frame data")
    frame_found, frame_filepath, frame_data = find_frame_of_lemma(lemma)
    if frame_found:
        print("----- frame xml file found: " + frame_filepath)
    else:
        print("----- frame xml file not found for lemma " + lemma)

    if frame_found:
        # -- Analyze frame data to get information
        print("-- Analyzing frame data")
        rs_found, rs_data = find_roleset_in_frame(frame_data, lemma, roleset_id)

        if rs_found:
            # str() around attribute values: a missing attribute is None
            # and cannot be concatenated to a string
            print("----- roleset id: " + str(rs_data.get('id')))
            print("----- roleset name: " + str(rs_data.get('name')))
            roles = rs_data.find_all('role')
            print("----- number of roles: " + str(len(roles)))
            # Iterate over the roles actually present rather than over
            # 0..count-1: a number in that range may have no role
            for role_data in roles:
                print("----- role " + str(role_data.get('n')) + ': ' +
                      str(role_data.get('f')) +
                      ', ' + str(role_data.get('descr')))
        else:
            print("----- roleset " + roleset_id + " not found")

        # -- Look for the target role (only for a core role, whose number
        #    is not negative). Boolean 'and' is required here: '&' binds
        #    tighter than a comparison and would be applied to the role
        #    number first.
        if rs_found and role_number >= 0:
            print("-- Finding role")
            print("----- role number: " + str(role_number))
            r_found, role_data = find_role_in_roleset(rs_data, role_number)
            if r_found:
                print("----- role " + str(role_number) + " found: " +
                      str(role_data.get('f')) +
                      ', ' + str(role_data.get('descr')))
            else:
                print("----- role " + str(role_number) + " not found")

    # -- Test for main function(s)
    print("-- Test for main function(s)")
    pb_role = find_pb_role(amr_predicate, amr_role)
    # str(): find_pb_role returns None when nothing is found
    print("----- find_pb_role(amr_predicate, amr_role) = " + str(pb_role))

    # -- Ending print
    print("\n" + "[SSC] Done")
    
def dev_test_1():
    """ Run the dev analysis on the sample pair 'include-01' / ':ARG0'
    """
    dev_analyze(amr_predicate='include-01', amr_role=':ARG0')