import unittest

from rdflib import Graph

import extract
import transform
from common import *


class TestObjectCount(unittest.TestCase):
    """Sanity checks on the RDF graph produced by the extract/transform pipeline.

    The pipeline (``extract.main()`` then ``transform.main()``) is run once,
    ``RDF_FULL_FILE`` is loaded into an rdflib ``Graph``, and a series of SPARQL
    SELECT queries must return the expected number of rows for the current
    ``MACAO_VERSION``.
    """

    # Expected-count tuples given to `assertCount` are indexed by the position
    # of MACAO_VERSION in this tuple.
    _VERSIONS = ("macao_12", "macao_3", "full")

    @classmethod
    def setUpClass(cls) -> None:
        """Run the extraction pipeline and load the resulting graph into ``cls.graph``.

        Done here rather than in ``__init__`` so the expensive work runs once per
        class as part of the test run (errors are reported as test errors, not
        loader errors), and so ``TestCase.__init__`` keeps receiving the
        ``methodName`` argument from the test loader.
        """
        super().setUpClass()
        # Run extraction
        extract.main()
        transform.main()
        # Load graph
        cls.graph = Graph()
        cls.graph.bind("", NS)
        cls.graph.parse(RDF_FULL_FILE)

    def runTest(self):
        """Check object counts for every entity type in the generated graph."""
        # Modules, excluding those that are also typed :SousPartie
        self.assertCount(
            """SELECT * WHERE {
                ?subj a :Module .

                ?subj :id ?id .
                ?subj :index ?index .
                ?subj :titre ?titre .

                MINUS { ?subj a :SousPartie }
                }""",
            (9, 6, 9 + 6),
        )
        # Sub-parts (:SousPartie)
        self.assertCount(
            """SELECT * WHERE {
                ?subj a :SousPartie .

                ?subj :id ?id .
                ?subj :index ?index .
                ?subj :titre ?titre .
                }""",
            (18, 14, 18 + 14),
        )
        # Activities (:Activite)
        self.assertCount(
            """SELECT * WHERE {
                ?subj a :Activite .

                ?subj :id ?id .
                ?subj :index ?index .
                ?subj :titre ?titre .
                ?subj :description ?desc .
                }""",
            (132, 86, 132 + 86),
        )

        # Activity types.
        # The exercise queries below project only ?subj: their patterns never
        # bind a description, so each count is the number of distinct matching
        # exercises.
        # Lessons (:Cours)
        self.assertCount(
            """SELECT * WHERE {
                ?subj a :Cours .

                ?subj :description ?desc .
                }""",
            (59, 26, 59 + 26),
        )
        # QCU exercises having at least one :aReponse
        self.assertCount(
            """SELECT DISTINCT ?subj WHERE {
                ?subj a :ExerciceQC_QCU .

                ?subj :aReponse ?rep .
                }""",
            (39, 25, 39 + 25),
        )
        # QCM exercises having at least one :aReponse
        self.assertCount(
            """SELECT DISTINCT ?subj WHERE {
                ?subj a :ExerciceQC_QCM .

                ?subj :aReponse ?rep .
                }""",
            (9, 6, 9 + 6),
        )
        # QM exercises
        self.assertCount(
            """SELECT DISTINCT ?subj WHERE {
                ?subj a :ExerciceQM .
                }""",
            (8, 3, 8 + 3),
        )
        # TAT exercises
        self.assertCount(
            """SELECT DISTINCT ?subj WHERE {
                ?subj a :ExerciceTAT .
                }""",
            (12, 25, 12 + 25),
        )
        # GD exercises
        self.assertCount(
            """SELECT DISTINCT ?subj WHERE {
                ?subj a :ExerciceGD .
                }""",
            (5, 1, 5 + 1),
        )

        # Other entities
        # Answers (:Reponse)
        self.assertCount(
            """SELECT * WHERE {
                ?subj a :Reponse .

                ?subj :id ?id .
                ?subj :index ?index .
                ?subj :correct ?correct .
                ?subj :html ?html .
                }""",
            # The minus values are to account for missed gaps in TAT activities
            # (see warnings when running the extraction), which are caused
            # by a known but tricky bug
            (258 - 30, 161, 258 - 30 + 161),
        )
        # TAT segments that are not fields (:Champ)
        self.assertCount(
            """SELECT * WHERE {
                ?subj a :Segment .

                ?subj :index ?index .
                ?subj :text ?text .
                MINUS { ?subj a :Champ }
                }""",
            (28, 42, 28 + 42),
        )
        # TAT fields: nodes typed both :Champ and :Segment
        self.assertCount(
            """SELECT * WHERE {
                ?subj a :Champ ;
                      a :Segment .

                ?subj :index ?index .
                ?subj :selection ?selection .
                }""",
            (16, 18, 16 + 18),
        )

    def assertCount(
        self, query: str, expected_tuple: tuple[int | None, int | None, int | None]
    ):
        """Check that `query` produces the expected number of results.

        Every Literal bound in the result rows must also be non-empty and must
        not be the literal string "None".

        Args:
            query: SPARQL SELECT query evaluated against ``self.graph``.
            expected_tuple: expected row counts for the macao_12 graph, the
                macao_3 graph, and both combined ("full"), in that order. A
                ``None`` entry skips the count check for that version.
        """
        res = self.graph.query(query)
        count = len(res)
        # Check that variables bound to Literals are not empty or None
        for binding in res.bindings:
            for var, val in binding.items():
                if isinstance(val, Literal):
                    self.assertFalse(
                        val.eq(Literal("")), f"Empty value: ?{var} = '{val}'"
                    )
                    self.assertFalse(
                        val.eq(Literal("None")), f"None value: ?{var} = '{val}'"
                    )

        # Report the same value that is used for the lookup.
        if MACAO_VERSION not in self._VERSIONS:
            self.fail(f"Unknown version '{MACAO_VERSION}'")
        expected = expected_tuple[self._VERSIONS.index(MACAO_VERSION)]
        if expected is None:
            return  # no expectation recorded for this version: skip the check
        self.assertEqual(
            count, expected, f"Unexpected number of results for query:\n{query}"
        )


if __name__ == "__main__":
    # Executed directly (not imported): discover and run this module's tests.
    unittest.main(module="__main__")
