# Compares ODRL results produced by a system (e.g. the output of Tenet)
# against the expected ODRL representations of a reference corpus.
#
# Every ODRL file of the corpus directory must have a counterpart with the
# same file name in the output directory.

import os
import sys
from rdflib import Graph

class ODRL:
	"""Simplified view of an ODRL policy: which modalities it contains and
	which actions are attached to each modality.

	Attributes:
		has_perm, has_obl, has_proh: True once at least one odrl:permission /
			odrl:obligation / odrl:prohibition rule has been seen.
		perm, obl, proh: actions (as strings) attached to the rules of each
			modality.
	"""

	def __init__(self):
		self.has_perm = False
		self.has_obl = False
		self.has_proh = False
		self.perm = []
		self.obl = []
		self.proh = []

	def parse(self, odrl_fname):
		"""Parse the RDF file `odrl_fname` with rdflib and add its rules to
		this ODRL."""
		graph = Graph()
		graph.parse(odrl_fname)
		# Iterating an rdflib Graph yields (subject, predicate, object) triples.
		self._add_triples(graph)

	def _add_triples(self, triples):
		"""Add the rules described by an iterable of (subject, predicate,
		object) triples. Terms are compared through their str() form.

		A node carrying odrl:action that is not the target of any modality
		property is reported as ill-formed and its actions are ignored.
		"""
		odrl_ns = "http://www.w3.org/ns/odrl/2/"
		# A policy may hold several rules of the same modality, so every rule
		# node is kept (not only the last one seen).
		perm_nodes = set()
		obl_nodes = set()
		proh_nodes = set()
		actions = {}  # rule node -> actions listed on that node

		# First pass: list modality nodes and the actions of every node.
		# Triples come in arbitrary order, so linking must wait until the end.
		for (src, prop, tgt) in triples:
			src, prop, tgt = str(src), str(prop), str(tgt)
			if prop == odrl_ns + "permission":
				self.has_perm = True
				perm_nodes.add(tgt)
			elif prop == odrl_ns + "obligation":
				self.has_obl = True
				obl_nodes.add(tgt)
			elif prop == odrl_ns + "prohibition":
				self.has_proh = True
				proh_nodes.add(tgt)
			elif prop == odrl_ns + "action":
				actions.setdefault(src, []).append(tgt)

		# Second pass: attach every node's actions to its modality.
		for node, node_actions in actions.items():
			if node in perm_nodes:
				mod = self.perm
			elif node in obl_nodes:
				mod = self.obl
			elif node in proh_nodes:
				mod = self.proh
			else:
				# Skip instead of appending to an unbound/stale list.
				print("Warning: ill-formed ODRL", file=sys.stderr)
				continue
			mod.extend(node_actions)

	def actions(self):
		"""Return a new list with the permission, obligation and prohibition
		actions, in that order."""
		return self.perm + self.obl + self.proh

	def __str__(self):
		lines = []
		for present, title, acts in (
			(self.has_perm, "Permissions:", self.perm),
			(self.has_obl, "Obligations:", self.obl),
			(self.has_proh, "Prohibitions:", self.proh),
		):
			if present:
				lines.append(title + "\n")
				lines.extend(f"\t{act}\n" for act in acts)
		return "".join(lines)


class Scores:
	"""Accumulate true-positive / false-positive / false-negative counts per
	criterion and derive precision and recall from them.

	True negatives are not counted; they do not enter precision or recall.
	A metric whose denominator is zero (undefined) is reported as 0.0.
	"""

	def __init__(self):
		self.scores = {}  # criterion -> {"tp": int, "fp": int, "fn": int}

	@staticmethod
	def _ratio(num, den):
		"""Return num / den, or 0.0 when den is 0 (metric undefined)."""
		return num / den if den else 0.0

	def _total(self, key):
		"""Sum the `key` count ("tp", "fp" or "fn") over all criteria."""
		return sum(counts[key] for counts in self.scores.values())

	def check_and_add_criterion(self, crit):
		"""Register `crit` with zeroed counts if it is not known yet."""
		if crit not in self.scores:
			self.scores[crit] = {"tp": 0, "fp": 0, "fn": 0}

	def add_measure(self, crit, tgt, output):
		"""Add a single measure for criterion `crit`.

		Args:
			crit: name of the criterion.
			tgt (bool): real (expected) value.
			output (bool): measured value.

		A (False, False) pair registers the criterion but changes no count.
		"""
		self.check_and_add_criterion(crit)
		counts = self.scores[crit]
		if tgt and output:
			counts["tp"] += 1
		elif not tgt and output:
			counts["fp"] += 1
		elif tgt and not output:
			counts["fn"] += 1

	def get_precision(self, crit):
		counts = self.scores[crit]
		return self._ratio(counts["tp"], counts["tp"] + counts["fp"])

	def get_recall(self, crit):
		counts = self.scores[crit]
		return self._ratio(counts["tp"], counts["tp"] + counts["fn"])

	def get_total_precision(self):
		"""Micro-averaged precision over all criteria."""
		tot_tp = self._total("tp")
		return self._ratio(tot_tp, tot_tp + self._total("fp"))

	def get_total_recall(self):
		"""Micro-averaged recall over all criteria."""
		tot_tp = self._total("tp")
		return self._ratio(tot_tp, tot_tp + self._total("fn"))

	def __str__(self):
		lines = []
		for crit, counts in self.scores.items():
			lines.append(f"{crit}:")
			lines.append(f'\tTP = {counts["tp"]}, FP = {counts["fp"]}, FN = {counts["fn"]}')
			lines.append(f'\tPrecision = {self.get_precision(crit)}, Recall = {self.get_recall(crit)}')
		lines.append(f'Total precision = {self.get_total_precision()}')
		lines.append(f'Total recall = {self.get_total_recall()}')
		return "\n".join(lines) + "\n"
		


def add_comp_modalities(odrl_target, odrl_output, scores):
	"""Measure, for each modality, whether its presence in `odrl_output`
	matches its presence in `odrl_target`; the modality name is the
	criterion recorded in `scores`."""
	modality_flags = (
		("permission", "has_perm"),
		("obligation", "has_obl"),
		("prohibition", "has_proh"),
	)
	for criterion, flag in modality_flags:
		expected = getattr(odrl_target, flag)
		measured = getattr(odrl_output, flag)
		scores.add_measure(criterion, expected, measured)

def add_comp_actions(odrl_target, odrl_output, scores):
	"""Score the actions of `odrl_output` against those of `odrl_target`.

	Actions are compared regardless of their modality (modalities are
	scored separately by add_comp_modalities). Each distinct action of a
	document is measured once, with the action itself as the criterion:
	  - present in both documents  -> true positive
	  - present only in the output -> false positive
	  - present only in the target -> false negative
	"""
	# dict.fromkeys drops duplicates while keeping first-seen order, so an
	# action listed under several modalities is not measured several times
	# and criteria are still registered in a deterministic order.
	actions_target = dict.fromkeys(odrl_target.actions())
	actions_output = dict.fromkeys(odrl_output.actions())

	for action in actions_output:
		scores.add_measure(action, action in actions_target, True)
	for action in actions_target:
		if action not in actions_output:  # (True, True) pairs were counted above
			scores.add_measure(action, True, False)

def add_comp_global(odrl_target, odrl_output, scores):
	"""Record under the "global" criterion whether `odrl_output` has exactly
	the same set of actions as `odrl_target` for every modality (ordering
	and repetitions are ignored)."""
	identical = all(
		set(getattr(odrl_target, modality)) == set(getattr(odrl_output, modality))
		for modality in ("perm", "obl", "proh")
	)
	scores.add_measure("global", True, identical)


""" path_target: path to the ODRL files of the corpus
    path_output: path to the ODRL files of the output """
def main_scorer(path_target, path_output):
	
	scores_modalities = Scores()
	scores_actions = Scores()
	scores_global = Scores()
	
	for fname in os.listdir(path_odrl):
		if fname.endswith(".ttl"):
		
			odrl_target = ODRL()
			odrl_target.parse(path_odrl + fname)
			odrl_output = ODRL()
			odrl_output.parse(path_output + fname)
			
			add_comp_modalities(odrl_target, odrl_output, scores_modalities)
			add_comp_actions(odrl_target, odrl_output, scores_actions)
			add_comp_global(odrl_target, odrl_output, scores_global)
	
	return scores_modalities, scores_actions, scores_global



if __name__ == "__main__":

	if len(sys.argv) < 3:
		print(f"Usage: python3 {sys.argv[0]} <path_corpus_odrl> <path_output_odrl")
		exit(1)

	path_target = sys.argv[1]
	path_output = sys.argv[2]
	
	scores_modalities, scores_actions, scores_global = main_scorer(path_target, path_output)
	
	print(f"Modality scores:\n{scores_modalities}")
	print(f"Actions scores:\n{scores_actions}")
	print(f"Global scores:\n{scores_global}")