# -*- coding: utf-8 -*-

# Copyright (c) 2017 Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)

import re
from copy import deepcopy

try:
    import botocore
except ImportError:
    pass  # Modules are responsible for handling this.

from ansible.module_utils._text import to_native
from ansible.module_utils.common.dict_transformations import camel_dict_to_snake_dict

from .arn import parse_aws_arn
from .arn import validate_aws_arn
from .botocore import is_boto3_error_code
from .errors import AWSErrorHandler
from .exceptions import AnsibleAWSError
from .retries import AWSRetry
from .tagging import ansible_dict_to_boto3_tag_list
from .tagging import boto3_tag_list_to_ansible_dict


class AnsibleIAMError(AnsibleAWSError):
    """IAM-specific error type.

    Configured as the custom exception of IAMErrorHandler and raised directly
    by helpers in this module (for example when a managed policy name cannot
    be resolved).
    """


class IAMErrorHandler(AWSErrorHandler):
    """AWSErrorHandler specialisation used to decorate the IAM helpers in this module.

    NOTE(review): how the parent class consumes ``_CUSTOM_EXCEPTION`` and
    ``_is_missing`` is defined in AWSErrorHandler, outside this file - confirm there.
    """

    # Exception class associated with this handler (presumably what the
    # parent's decorators raise on failure - confirm in AWSErrorHandler).
    _CUSTOM_EXCEPTION = AnsibleIAMError

    @classmethod
    def _is_missing(cls):
        """Return the ``is_boto3_error_code`` matcher for the "NoSuchEntity" error code."""
        return is_boto3_error_code("NoSuchEntity")


@AWSRetry.jittered_backoff()
def _tag_iam_instance_profile(client, **kwargs):
    """Call ``client.tag_instance_profile(**kwargs)`` under the AWSRetry decorator.

    The API response is discarded; the function returns None.
    """
    client.tag_instance_profile(**kwargs)


@AWSRetry.jittered_backoff()
def _untag_iam_instance_profile(client, **kwargs):
    """Call ``client.untag_instance_profile(**kwargs)`` under the AWSRetry decorator.

    The API response is discarded; the function returns None.
    """
    client.untag_instance_profile(**kwargs)


@AWSRetry.jittered_backoff()
def _get_iam_instance_profiles(client, **kwargs):
    """Fetch a single instance profile via ``client.get_instance_profile``.

    Returns the "InstanceProfile" element of the response.
    """
    response = client.get_instance_profile(**kwargs)
    return response["InstanceProfile"]


@AWSRetry.jittered_backoff()
def _list_iam_instance_profiles(client, **kwargs):
    """Run the "list_instance_profiles" paginator to completion.

    Returns the "InstanceProfiles" element of the combined result.
    """
    pages = client.get_paginator("list_instance_profiles").paginate(**kwargs)
    full_result = pages.build_full_result()
    return full_result["InstanceProfiles"]


@AWSRetry.jittered_backoff()
def _list_iam_instance_profiles_for_role(client, **kwargs):
    """Run the "list_instance_profiles_for_role" paginator to completion.

    Returns the "InstanceProfiles" element of the combined result.
    """
    pages = client.get_paginator("list_instance_profiles_for_role").paginate(**kwargs)
    full_result = pages.build_full_result()
    return full_result["InstanceProfiles"]


@AWSRetry.jittered_backoff()
def _create_instance_profile(client, **kwargs):
    """Call ``client.create_instance_profile(**kwargs)`` and return the raw response."""
    response = client.create_instance_profile(**kwargs)
    return response


@AWSRetry.jittered_backoff()
def _delete_instance_profile(client, **kwargs):
    """Call ``client.delete_instance_profile(**kwargs)`` under the AWSRetry decorator.

    The API response is discarded; the function returns None.
    """
    client.delete_instance_profile(**kwargs)


@AWSRetry.jittered_backoff()
def _add_role_to_instance_profile(client, **kwargs):
    """Call ``client.add_role_to_instance_profile(**kwargs)`` under the AWSRetry decorator.

    The API response is discarded; the function returns None.
    """
    client.add_role_to_instance_profile(**kwargs)


@AWSRetry.jittered_backoff()
def _remove_role_from_instance_profile(client, **kwargs):
    """Call ``client.remove_role_from_instance_profile(**kwargs)`` under the AWSRetry decorator.

    The API response is discarded; the function returns None.
    """
    client.remove_role_from_instance_profile(**kwargs)


@AWSRetry.jittered_backoff()
def _list_managed_policies(client, **kwargs):
    """Run the "list_policies" paginator to completion.

    Returns the whole combined result; callers pick out the key they need.
    """
    pages = client.get_paginator("list_policies").paginate(**kwargs)
    return pages.build_full_result()


@IAMErrorHandler.common_error_handler("list all managed policies")
def list_managed_policies(client):
    """Return the "Policies" element of the fully paginated list_policies result."""
    full_result = _list_managed_policies(client)
    return full_result["Policies"]


def convert_managed_policy_names_to_arns(client, policy_names):
    """Resolve a list of managed policy names and/or ARNs to ARNs.

    If every non-None entry already validates as an IAM ARN, ``policy_names``
    is returned unchanged (the very same object, None entries included) and
    no API call is made.  Otherwise all managed policies are listed and each
    non-None entry is looked up by policy name or by ARN; None entries are
    dropped from the result on this path.

    Raises AnsibleIAMError when an entry matches neither a known policy name
    nor a known policy ARN.
    """
    for candidate in policy_names:
        if candidate is None:
            continue
        if not validate_aws_arn(candidate, service="iam"):
            # At least one entry needs to be looked up by name.
            break
    else:
        return policy_names

    # Map both the name and the ARN of every policy onto its ARN, so that
    # mixed name/ARN input can be resolved with a single lookup.
    arn_lookup = {}
    for known_policy in list_managed_policies(client):
        arn = known_policy["Arn"]
        arn_lookup[known_policy["PolicyName"]] = arn
        arn_lookup[arn] = arn

    try:
        return [arn_lookup[requested] for requested in policy_names if requested is not None]
    except KeyError as e:
        raise AnsibleIAMError(message="Failed to find policy by name:" + str(e), exception=e) from e


def get_aws_account_id(module):
    """Given an AnsibleAWSModule instance, get the active AWS account ID.

    This is the first element of the tuple returned by get_aws_account_info.
    """
    account_info = get_aws_account_info(module)
    return account_info[0]


def get_aws_account_info(module):
    """Given an AnsibleAWSModule instance, return the account information
    (account id and partition) we are currently working on

    get_aws_account_info tries to find out the account that we are working
    on.  It's not guaranteed that this will be easy so we try in
    several different ways.  Giving either IAM or STS privileges to
    the account should be enough to permit this.

    Tries, in order:
    - sts:GetCallerIdentity
    - iam:GetUser
    - parsing an IAM ARN out of the AccessDenied error message returned
      by the iam:GetUser attempt

    Returns a tuple ``(account_id, partition)``, each passed through
    ``to_native``.  If no attempt yields both values the module is failed
    through ``module.fail_json`` / ``module.fail_json_aws``.
    """
    account_id = None
    partition = None
    try:
        # First choice: ask STS who we are.
        sts_client = module.client("sts", retry_decorator=AWSRetry.jittered_backoff())
        caller_id = sts_client.get_caller_identity(aws_retry=True)
        account_id = caller_id.get("Account")
        # The partition is the second colon-separated field of the caller ARN.
        partition = caller_id.get("Arn").split(":")[1]
    except (botocore.exceptions.BotoCoreError, botocore.exceptions.ClientError):
        # STS failed; fall back to IAM.
        try:
            iam_client = module.client("iam", retry_decorator=AWSRetry.jittered_backoff())
            # The user ARN is unpacked into exactly six colon-separated fields;
            # only the partition and account id are used.
            _arn, partition, _service, _reg, account_id, _resource = iam_client.get_user(aws_retry=True)["User"][
                "Arn"
            ].split(":")
        except is_boto3_error_code("AccessDenied") as e:
            # Last resort: try to recover an ARN from the text of the
            # AccessDenied error itself.
            try:
                except_msg = to_native(e.message)
            except AttributeError:
                # No ``message`` attribute on the exception; use the exception itself.
                except_msg = to_native(e)
            result = parse_aws_arn(except_msg)
            if result is None or result["service"] != "iam":
                # NOTE(review): presumably fail_json_aws does not return (it
                # ends the module run); otherwise ``result.get`` below would
                # fail when ``result`` is None - confirm.
                module.fail_json_aws(
                    e,
                    msg="Failed to get AWS account information, Try allowing sts:GetCallerIdentity or iam:GetUser permissions.",
                )
            account_id = result.get("account_id")
            partition = result.get("partition")
        except (  # pylint: disable=duplicate-except
            botocore.exceptions.BotoCoreError,
            botocore.exceptions.ClientError,
        ) as e:
            # Any other botocore failure from the IAM fallback is fatal.
            module.fail_json_aws(
                e,
                msg="Failed to get AWS account information, Try allowing sts:GetCallerIdentity or iam:GetUser permissions.",
            )

    # Guard against an attempt that "succeeded" without yielding both values.
    if account_id is None or partition is None:
        module.fail_json(
            msg="Failed to get AWS account information, Try allowing sts:GetCallerIdentity or iam:GetUser permissions.",
        )

    return (to_native(account_id), to_native(partition))


@IAMErrorHandler.common_error_handler("create instance profile")
def create_iam_instance_profile(client, name, path, tags):
    """Create an IAM instance profile and return it in boto3 format.

    ``tags`` is an Ansible-style tag dict (a falsy value is treated as no
    tags) and a falsy ``path`` is replaced by "/".  Returns the
    "InstanceProfile" element of the create response.
    """
    response = _create_instance_profile(
        client,
        InstanceProfileName=name,
        Path=path or "/",
        Tags=ansible_dict_to_boto3_tag_list(tags or {}),
    )
    return response["InstanceProfile"]


@IAMErrorHandler.deletion_error_handler("delete instance profile")
def delete_iam_instance_profile(client, name):
    """Delete the instance profile called ``name``.

    Returns True once the delete call has completed.
    """
    _delete_instance_profile(client, InstanceProfileName=name)
    # Error Handler will return False if the resource didn't exist
    return True


@IAMErrorHandler.common_error_handler("add role to instance profile")
def add_role_to_iam_instance_profile(client, profile_name, role_name):
    """Attach the role ``role_name`` to the instance profile ``profile_name``.

    Returns True once the call has completed.
    """
    _add_role_to_instance_profile(client, InstanceProfileName=profile_name, RoleName=role_name)
    return True


@IAMErrorHandler.deletion_error_handler("remove role from instance profile")
def remove_role_from_iam_instance_profile(client, profile_name, role_name):
    """Detach the role ``role_name`` from the instance profile ``profile_name``.

    Returns True once the call has completed.
    """
    _remove_role_from_instance_profile(client, InstanceProfileName=profile_name, RoleName=role_name)
    # Error Handler will return False if the resource didn't exist
    return True


@IAMErrorHandler.list_error_handler("list instance profiles", [])
def list_iam_instance_profiles(client, name=None, prefix=None, role=None):
    """
    Returns a list of IAM instance profiles in boto3 format.

    The filters are applied in priority order and only the first truthy one
    is used: ``role`` (profiles attached to that role), then ``name`` (that
    single profile), then ``prefix`` (profiles under that path prefix).
    With no filter, all instance profiles are listed.

    Profiles need to be converted to Ansible format using
    normalize_iam_instance_profile before being displayed.

    See also: normalize_iam_instance_profile
    """
    if role:
        return _list_iam_instance_profiles_for_role(client, RoleName=role)

    if name:
        # A lookup by name yields one profile; wrap it so that every branch
        # returns a list.
        single_profile = _get_iam_instance_profiles(client, InstanceProfileName=name)
        return [single_profile]

    list_filters = {"PathPrefix": prefix} if prefix else {}
    return _list_iam_instance_profiles(client, **list_filters)


def normalize_iam_instance_profile(profile):
    """
    Converts a boto3 format IAM instance profile into "Ansible" format

    The result is a snake_case copy of ``profile`` in which "roles" (when the
    profile has any) holds roles normalized with normalize_iam_role, "tags"
    is always an Ansible-style tag dict (empty when there are no tags) and
    "original" refers to the unmodified input dict.
    """
    normalized = camel_dict_to_snake_dict(deepcopy(profile))

    roles = profile.get("Roles")
    if roles:
        normalized["roles"] = list(map(normalize_iam_role, roles))

    tag_list = profile.get("Tags")
    normalized["tags"] = boto3_tag_list_to_ansible_dict(tag_list) if tag_list else {}

    # Keep a reference to the boto3-format input alongside the converted data.
    normalized["original"] = profile
    return normalized


def normalize_iam_role(role):
    """
    Converts a boto3 format IAM instance role into "Ansible" format

    The result is a snake_case copy of ``role`` in which "instance_profiles"
    (when present) holds profiles normalized with
    normalize_iam_instance_profile, "assume_role_policy_document" (when
    present) is the policy document exactly as supplied, "tags" is always an
    Ansible-style tag dict (empty when there are no tags) and "original"
    refers to the unmodified input dict.
    """
    normalized = camel_dict_to_snake_dict(deepcopy(role))

    instance_profiles = role.get("InstanceProfiles")
    if instance_profiles:
        normalized["instance_profiles"] = list(map(normalize_iam_instance_profile, instance_profiles))

    # Put back the policy document as supplied, replacing the snake_cased copy.
    policy_document = role.get("AssumeRolePolicyDocument")
    if policy_document:
        normalized["assume_role_policy_document"] = policy_document

    tag_list = role.get("Tags")
    normalized["tags"] = boto3_tag_list_to_ansible_dict(tag_list) if tag_list else {}

    # Keep a reference to the boto3-format input alongside the converted data.
    normalized["original"] = role
    return normalized


@IAMErrorHandler.common_error_handler("tag instance profile")
def tag_iam_instance_profile(client, name, tags):
    """Add tags to the instance profile called ``name``.

    ``tags`` is an Ansible-style tag dict.  When it is empty or None no API
    call is made.  Returns None.
    """
    if not tags:
        return
    # ``tags`` is known to be truthy here, so no empty-dict fallback is needed.
    boto3_tags = ansible_dict_to_boto3_tag_list(tags)
    _tag_iam_instance_profile(client, InstanceProfileName=name, Tags=boto3_tags)


@IAMErrorHandler.common_error_handler("untag instance profile")
def untag_iam_instance_profile(client, name, tags):
    """Remove tags from the instance profile called ``name``.

    ``tags`` is passed to the API as ``TagKeys``, i.e. it holds the keys of
    the tags to remove.  When it is empty or None no API call is made.
    Returns None.
    """
    if not tags:
        return
    _untag_iam_instance_profile(client, InstanceProfileName=name, TagKeys=tags)


def _validate_iam_name(resource_type, name=None):
    if name is None:
        return None
    LENGTHS = {"role": 64, "user": 64}
    regex = r"[\w+=,.@-]+"
    max_length = LENGTHS.get(resource_type, 128)
    if len(name) > max_length:
        return f"Length of {resource_type} name may not exceed {max_length}"
    if not re.fullmatch(regex, name):
        return f"{resource_type} name must match pattern {regex}"
    return None


def _validate_iam_path(resource_type, path=None):
    if path is None:
        return None
    regex = r"\/([\w+=,.@-]+\/)*"
    max_length = 512
    if len(path) > max_length:
        return f"Length of {resource_type} path may not exceed {max_length}"
    if not path.endswith("/") or not path.startswith("/"):
        return f"{resource_type} path must begin and end with /"
    if not re.fullmatch(regex, path):
        return f"{resource_type} path must match pattern {regex}"
    return None


def validate_iam_identifiers(resource_type, name=None, path=None):
    """Validate an IAM resource name and path.

    The name is checked first, then the path.  Returns the message for the
    first problem found, or None when both are valid (or not supplied).
    """
    checks = (
        (_validate_iam_name, name),
        (_validate_iam_path, path),
    )
    for validator, value in checks:
        problem = validator(resource_type, value)
        if problem:
            return problem

    return None
