# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import importlib
import inspect
import json
import os
import sys

from nsys_recipe import nsys_constants
from nsys_recipe.lib import recipe
from nsys_recipe.log import logger
from nsys_recipe.nsys_constants import NSYS_RECIPE_INSTALL_PATH, NSYS_RECIPE_REQ_PATH


def get_metadata_path(recipe_dir, recipe_name):
    """Return the path of the 'metadata.json' file of a recipe.

    Args:
        recipe_dir: Directory that contains the recipe directory.
        recipe_name: Name of the recipe directory inside 'recipe_dir'.

    Returns:
        The joined path '<recipe_dir>/<recipe_name>/metadata.json'. The path
        is only built here; its existence is not checked.
    """
    recipe_root = os.path.join(recipe_dir, recipe_name)
    return os.path.join(recipe_root, "metadata.json")


def get_metadata_dict(recipe_dir, recipe_name):
    """Load the 'metadata.json' file of a recipe.

    Args:
        recipe_dir: Directory that contains the recipe directory.
        recipe_name: Name of the recipe directory inside 'recipe_dir'.

    Returns:
        The deserialized JSON content of the metadata file, or None if the
        metadata file does not exist.

    Raises:
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    json_path = get_metadata_path(recipe_dir, recipe_name)
    # The existence check (rather than catching FileNotFoundError) also covers
    # the case where '<recipe_dir>/<recipe_name>' is a regular file.
    if not os.path.exists(json_path):
        return None

    # Decode explicitly as UTF-8 so that the result does not depend on the
    # locale's preferred encoding of the machine running the recipe.
    with open(json_path, encoding="utf-8") as f:
        return json.load(f)


def is_recipe_subclass(obj):
    """Tell whether 'obj' is a class derived from recipe.Recipe.

    Args:
        obj: Any object.

    Returns:
        True if 'obj' is a class that is a subclass of recipe.Recipe and is
        not recipe.Recipe itself, False otherwise.
    """
    # Anything that is not a class can be rejected right away; issubclass()
    # would raise a TypeError on such objects.
    if not inspect.isclass(obj):
        return False

    if not issubclass(obj, recipe.Recipe):
        return False

    # The base class itself is not considered a usable recipe.
    return obj != recipe.Recipe


def get_recipe_class_from_module(module, module_name, metadata_path, class_name=None):
    """Pick the Recipe class to use from an already imported recipe module.

    Args:
        module: The imported recipe module.
        module_name: Name of the module, only used in log messages.
        metadata_path: Path of the metadata file, only used in log messages.
        class_name: Optional name of the class to use. When None, the module
            is searched for Recipe subclasses.

    Returns:
        The selected Recipe subclass, or None (after logging an error) if the
        requested class is missing, is not a Recipe subclass, or if the module
        contains no Recipe subclass at all.
    """
    if class_name is None:
        # No class requested explicitly: look for Recipe subclasses among the
        # module members. inspect.getmembers() returns them sorted by name.
        candidates = inspect.getmembers(module, is_recipe_subclass)
        if not candidates:
            logger.error(
                f"No Recipe class found in the '{module_name}' module."
                f" Please update the '{metadata_path}' file."
            )
            return None

        first_name, first_class = candidates[0]
        if len(candidates) > 1:
            logger.warning(
                f"Multiple Recipe classes detected. Using the first class '{first_name}' as default."
                f" To choose a different class, please specify the 'class_name' field in the {metadata_path} file."
            )

        return first_class

    # A private sentinel tells a missing attribute apart from an attribute
    # whose value happens to be None.
    missing = object()
    candidate = getattr(module, class_name, missing)
    if candidate is missing:
        logger.error(
            f"The '{module_name}' module does not contain the '{class_name}' class."
            f" Please update the '{metadata_path}' file."
        )
        return None

    if not is_recipe_subclass(candidate):
        logger.error(
            f"The '{class_name}' class specified in the 'class_name' field is not a Recipe class."
            f" Please update the '{metadata_path}' file."
        )
        return None

    return candidate


def import_module_from_path(module_name, file_path):
    """Import a Python source file as a module registered under 'module_name'.

    Args:
        module_name: Name under which the module is stored in sys.modules.
        file_path: Path of the file to load.

    Returns:
        The imported module.

    Raises:
        ImportError: If no loader can be determined for 'file_path' (for
            example a directory or a file with an unsupported suffix).
        Exception: Any exception raised while executing the module code is
            propagated unchanged.
    """
    # 'import importlib' alone does not guarantee that the 'util' submodule is
    # loaded, so import it explicitly before using it.
    import importlib.util

    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(
            f"Cannot load module '{module_name}' from '{file_path}'.",
            name=module_name,
            path=file_path,
        )

    module = importlib.util.module_from_spec(spec)
    # Register the module before executing it so that code run during the
    # import can already find it in sys.modules.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a partially initialized module behind when its
        # execution fails; the exception itself is re-raised untouched.
        if sys.modules.get(module_name) is module:
            del sys.modules[module_name]
        raise
    return module


def get_recipe_module(search_path, recipe_name, module_name, metadata_path):
    """Locate and import the module of a recipe.

    The module file is looked up at '<search_path>/<recipe_name>/<module_name>',
    first as given and then with a '.py' suffix appended. As a side effect,
    'search_path' is appended to sys.path if it is not already listed there.

    Args:
        search_path: Directory that contains the recipe directory.
        recipe_name: Name of the recipe directory.
        module_name: Value of the 'module_name' metadata field.
        metadata_path: Path of the metadata file, only used in log messages.

    Returns:
        The imported module, or None (after logging an error) if the module
        file cannot be found or if an import inside it fails while a
        requirements file exists for the recipe.

    Raises:
        ModuleNotFoundError: If an import fails and no requirements file
            exists for the recipe.
    """
    if search_path not in sys.path:
        sys.path.append(search_path)

    qualified_name = f"{recipe_name}.{module_name}"
    module_path = os.path.join(search_path, recipe_name, module_name)

    if not os.path.exists(module_path):
        # Fall back to the same path with the '.py' suffix appended.
        path_with_suffix = f"{module_path}.py"
        if not os.path.exists(path_with_suffix):
            logger.error(
                f"'{module_path}' not found."
                f" Please make sure that the 'module_name' field in the '{metadata_path}' file points to a valid module."
            )
            return None
        module_path = path_with_suffix

    try:
        # Due to the lack of an __init__.py file in the different recipe
        # directories, the recipes aren't recognized as modules. To address
        # this, we dynamically load the recipe module from the exact file path.
        return import_module_from_path(qualified_name, module_path)
    except ModuleNotFoundError as e:
        req_file = os.path.join(NSYS_RECIPE_REQ_PATH, recipe_name + ".txt")
        if not os.path.exists(req_file):
            # Without a requirements file there is no installation hint to
            # give, so let the original error propagate.
            raise

        install_file = NSYS_RECIPE_INSTALL_PATH
        logger.error(
            f"{e}\nAll packages listed in '{req_file}' must be installed."
            f" You can automate the installation using the '{install_file} --recipe={recipe_name}' script."
            " For more information, please refer to the Nsight Systems User Guide."
        )
        return None


def get_recipe_class_from_name(recipe_name):
    """Find a recipe by name and return its Recipe class.

    The first search location that holds a metadata file for the recipe is
    used; later locations are not consulted, even if loading from the first
    one fails.

    Args:
        recipe_name: Name of the recipe directory to look for.

    Returns:
        The Recipe class of the recipe, or None (after logging an error) if
        the recipe is unknown, its metadata lacks the 'module_name' field, or
        its module or class cannot be loaded.
    """
    # Locations are listed in decreasing priority:
    # 1. Directory set by the environment variable NSYS_RECIPE_PATH.
    # 2. Current directory.
    # 3. Default 'nsys_recipe/recipes' directory.
    search_paths = []
    env_search_path = os.getenv("NSYS_RECIPE_PATH")
    if env_search_path is not None:
        search_paths.append(env_search_path)
    search_paths += ["", nsys_constants.NSYS_RECIPE_RECIPES_PATH]

    for search_path in search_paths:
        metadata = get_metadata_dict(search_path, recipe_name)
        if metadata is not None:
            break
    else:
        # None of the locations holds a metadata file for this recipe.
        logger.error("Unknown recipe.")
        return None

    metadata_path = get_metadata_path(search_path, recipe_name)
    module_name = metadata.get("module_name")
    if module_name is None:
        logger.error(
            f"The '{metadata_path}' file does not contain the 'module_name' field."
            " Please set the 'module_name' field to the name of the module that contains the Recipe class."
        )
        return None

    module = get_recipe_module(search_path, recipe_name, module_name, metadata_path)
    if module is None:
        return None

    # 'class_name' is optional; when absent the class is discovered from the
    # module content.
    return get_recipe_class_from_module(
        module, module_name, metadata_path, metadata.get("class_name")
    )
