# Copyright (c) 2022, Dell Inc. or its subsidiaries.  All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# See the LICENSE file for details.
#
# This file is part of NVMe STorage Appliance Services (nvme-stas).
#
# Authors: Martin Belanger <Martin.Belanger@dell.com>

'''A collection of IP address and network interface utilities'''

import ipaddress
import logging
import socket
import sys

from staslib import conf

# Netlink / rtnetlink constants (values from linux/netlink.h, linux/rtnetlink.h, linux/if_addr.h)
RTM_NEWADDR = 20
RTM_GETADDR = 22
NLM_F_REQUEST = 0x01
NLM_F_ROOT = 0x100
NLMSG_DONE = 3
IFLA_ADDRESS = 1  # Same value as IFA_ADDRESS, the address rtattr type of RTM_NEWADDR messages
NLMSGHDR_SZ = 16  # sizeof(struct nlmsghdr)
IFADDRMSG_SZ = 8  # sizeof(struct ifaddrmsg)
RTATTR_SZ = 4  # sizeof(struct rtattr)

# Netlink request (Get address command): struct nlmsghdr followed by struct ifaddrmsg.
# Netlink messages are encoded in the host's native byte order, hence sys.byteorder.
GETADDRCMD = (
    # BEGIN: struct nlmsghdr
    b'\0' * 4  # nlmsg_len (__u32, placeholder - actual length calculated below)
    + (RTM_GETADDR).to_bytes(2, byteorder=sys.byteorder, signed=False)  # nlmsg_type (__u16)
    + (NLM_F_REQUEST | NLM_F_ROOT).to_bytes(2, byteorder=sys.byteorder, signed=False)  # nlmsg_flags (__u16)
    + b'\0' * 4  # nlmsg_seq (__u32)
    + b'\0' * 4  # nlmsg_pid (__u32)
    # END: struct nlmsghdr
    + b'\0' * IFADDRMSG_SZ  # struct ifaddrmsg (all zero, i.e. ifa_family = AF_UNSPEC)
)
GETADDRCMD = len(GETADDRCMD).to_bytes(4, byteorder=sys.byteorder) + GETADDRCMD[4:]  # nlmsg_len


# ******************************************************************************
def get_ipaddress_obj(ipaddr):
    '''@brief Convert @ipaddr into an ipaddress object.
    @param ipaddr: Candidate IP address (typically a string)
    @return An IPv4Address or IPv6Address when @ipaddr is a valid IPv4 or
            IPv6 address, None otherwise.
    '''
    try:
        return ipaddress.ip_address(ipaddr)
    except ValueError:
        # ip_address() raises ValueError for anything that is neither IPv4 nor IPv6
        return None


# ******************************************************************************
def _data_matches_ip(data_family, data, ip):
    '''@brief Tell whether the raw address @data, of address family
    @data_family, designates the same address as @ip. An IPv4 address and
    its IPv4-mapped IPv6 form (::ffff:a.b.c.d) are considered equal.
    @param data_family: socket.AF_INET or socket.AF_INET6 (any other value never matches)
    @param data: Packed address bytes
    @param ip: ipaddress.IPv4Address or ipaddress.IPv6Address to compare against
    @return True if both designate the same address, False otherwise (including
            when @data cannot be parsed as an address of @data_family)
    '''
    is_v4_data = data_family == socket.AF_INET
    if not is_v4_data and data_family != socket.AF_INET6:
        return False

    addr_type = ipaddress.IPv4Address if is_v4_data else ipaddress.IPv6Address
    try:
        candidate = addr_type(data)
    except ValueError:
        return False

    # Bring both sides to the same IP version before comparing. ipv4_mapped
    # yields None when the IPv6 address is not IPv4-mapped, which never matches.
    if is_v4_data and ip.version == 6:
        ip = ip.ipv4_mapped
    elif not is_v4_data and ip.version == 4:
        candidate = candidate.ipv4_mapped

    return candidate == ip


# ******************************************************************************
_IFA_LOCAL = 2  # rtattr type of the local address in RTM_NEWADDR messages (linux/if_addr.h)


def _nl_uint(buf, offset, size):
    '''@brief Decode the unsigned integer of @size bytes located at @offset in @buf.
    Netlink messages are encoded in the host's native byte order.'''
    return int.from_bytes(buf[offset : offset + size], byteorder=sys.byteorder, signed=False)


def _parse_newaddr(buf, msg_start, msg_end):
    '''@brief Extract the address information carried by one RTM_NEWADDR message.
    @param buf: Buffer containing the message
    @param msg_start: Offset of the message's struct nlmsghdr within @buf
    @param msg_end: Offset just past the end of the message within @buf. The
           caller guarantees that a full struct ifaddrmsg fits before it.
    @return Tuple (ifa_family, ifa_index, addr) where addr is the raw bytes of
            the interface's own address, or None if the message carries none.
    '''
    ifaddrmsg_indx = msg_start + NLMSGHDR_SZ
    ifa_family = buf[ifaddrmsg_indx]
    ifa_index = _nl_uint(buf, ifaddrmsg_indx + 4, 4)

    address = local = None
    rtattr_indx = ifaddrmsg_indx + IFADDRMSG_SZ
    while rtattr_indx + RTATTR_SZ <= msg_end:
        rta_len = _nl_uint(buf, rtattr_indx, 2)
        if rta_len < RTATTR_SZ or rtattr_indx + rta_len > msg_end:
            break  # Malformed attribute: stop instead of looping forever on rta_len == 0
        rta_type = _nl_uint(buf, rtattr_indx + 2, 2)
        data = buf[rtattr_indx + RTATTR_SZ : rtattr_indx + rta_len]
        if rta_type == IFLA_ADDRESS:
            address = data
        elif rta_type == _IFA_LOCAL:
            local = data
        rtattr_indx += (rta_len + 3) & ~3  # Attributes are padded to a multiple of 4

    # On point-to-point links IFA_ADDRESS is the *peer* address and the local
    # address is in IFA_LOCAL, so prefer IFA_LOCAL whenever it is present.
    return ifa_family, ifa_index, (local if local is not None else address)


# ******************************************************************************
def iface_of(src_addr):
    '''@brief Find the interface that has src_addr as one of its assigned IP addresses.
    @param src_addr: The IP address to match
    @type src_addr: Instance of ipaddress.IPv4Address or ipaddress.IPv6Address
    @return The name of the matching interface, or '' if none is found.
    '''
    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW) as sock:
        sock.sendall(GETADDRCMD)
        while True:
            # A Netlink datagram only contains complete messages, so each one is parsed on its own.
            buf = sock.recv(8192)
            if not buf:
                break

            nlmsg_idx = 0
            while nlmsg_idx + NLMSGHDR_SZ <= len(buf):
                nlmsg_len = _nl_uint(buf, nlmsg_idx, 4)
                nlmsg_type = _nl_uint(buf, nlmsg_idx + 4, 2)
                if nlmsg_type == NLMSG_DONE:
                    return ''

                nlmsg_end = nlmsg_idx + nlmsg_len
                # Any other message type (e.g. NLMSG_ERROR), or a message too short
                # to hold its headers or running past the datagram, ends the search.
                if nlmsg_type != RTM_NEWADDR or nlmsg_len < NLMSGHDR_SZ + IFADDRMSG_SZ or nlmsg_end > len(buf):
                    return ''

                ifa_family, ifa_index, addr = _parse_newaddr(buf, nlmsg_idx, nlmsg_end)
                if addr is not None and _data_matches_ip(ifa_family, addr, src_addr):
                    return socket.if_indextoname(ifa_index)

                nlmsg_idx += (nlmsg_len + 3) & ~3  # NLMSG_ALIGN: move to next Netlink message

    return ''


# ******************************************************************************
def get_interface(src_addr):
    '''Get interface for given source address
    @param src_addr: The source address, optionally followed by a "%scope-id" suffix
    @type src_addr: str
    @return The interface name, or '' when src_addr is empty, is not a valid
            IP address, or is not assigned to any interface.
    '''
    if not src_addr:
        return ''

    # Drop the scope-id (if any), e.g. "fe80::1%eth0" -> "fe80::1"
    bare_addr, _, _ = src_addr.partition('%')
    ip = get_ipaddress_obj(bare_addr)
    if ip is None:
        return ''
    return iface_of(ip)


# ******************************************************************************
def remove_invalid_addresses(controllers: list):
    '''@brief Remove controllers with invalid addresses from the list of controllers.
    TCP and RDMA controllers are kept only when traddr is a syntactically valid
    IPv4/IPv6 address whose family is listed in the service configuration's
    ip_family. FC and loop controllers are kept as-is. Controllers with any
    other transport are dropped with a warning.
    @param controllers: List of TIDs
    @return A new list holding the controllers that were kept (input order preserved)
    '''
    service_conf = conf.SvcConf()
    kept = []
    for tid in controllers:
        transport = tid.transport
        if transport in ('fc', 'loop'):
            # At some point, need to validate FC addresses as well...
            kept.append(tid)
        elif transport not in ('tcp', 'rdma'):
            logging.warning('Invalid transport %s', transport)
        else:
            # traddr must be syntactically a valid IPv4 or IPv6 address
            # and its address family must be enabled.
            ip = get_ipaddress_obj(tid.traddr)
            if ip is None:
                logging.warning('%s IP address is not valid', tid)
            elif ip.version not in service_conf.ip_family:
                logging.debug(
                    '%s ignored because IPv%s is disabled in %s',
                    tid,
                    ip.version,
                    service_conf.conf_file,
                )
            else:
                kept.append(tid)

    return kept