# SPDX-License-Identifier: GPL-2.0-or-later
#
# Copyright (c) 2015-2016 Western Digital Corporation or its affiliates.
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
# MA  02110-1301, USA.
#
#   Author: Chaitanya Kulkarni <chaitanya.kulkarni@hgst.com>
#
""" Base class for all the testcases
"""

import json
import mmap
import os
import re
import shutil
import stat
import subprocess
import sys
import time
import unittest

from nvme_test_logger import TestNVMeLogger


class TestNVMe(unittest.TestCase):

    """
    Represents a testcase, each testcase should inherit this
    class or appropriate subclass which is a child of this class.

    Common utility functions used in various testcases.

        - Attributes:
            - ctrl : NVMe Controller.
            - ns1 : default namespace.
            - default_nsid : default namespace id.
            - config_file : configuration file.
            - clear_log_dir : default log directory.
    """

    def setUp(self):
        """ Pre Section for TestNVMe. """
        # common code used in various testcases.
        self.ctrl = "XXX"
        self.ns1 = "XXX"
        self.test_log_dir = "XXX"
        self.do_validate_pci_device = True
        self.default_nsid = 0x1
        self.config_file = 'tests/config.json'

        self.load_config()
        if self.do_validate_pci_device:
            self.validate_pci_device()

    def tearDown(self):
        """ Post Section for TestNVMe. """
        if self.clear_log_dir is True:
            shutil.rmtree(self.log_dir, ignore_errors=True)
        self.create_and_attach_default_ns()

    def create_and_attach_default_ns(self):
        """ Creates a default namespace with the full capacity of the ctrls NVM
            - Args:
                - None
            - Returns:
                - None
        """
        self.dps = 0
        self.flbas = 0
        (ds, ms) = self.get_lba_format_size()
        ncap = int(self.get_ncap() / (ds+ms))
        self.nsze = ncap
        self.ncap = ncap
        self.ctrl_id = self.get_ctrl_id()
        self.delete_all_ns()
        err = self.create_and_validate_ns(self.default_nsid,
                                          self.nsze,
                                          self.ncap,
                                          self.flbas,
                                          self.dps)
        self.assertEqual(err, 0)
        self.assertEqual(self.attach_ns(self.ctrl_id, self.default_nsid), 0)

    def validate_pci_device(self):
        """ Validate underlying device belongs to pci subsystem.
            - Args:
                - None
            - Returns:
                - None
        """
        x1, x2, dev = self.ctrl.split('/')
        cmd = cmd = "find /sys/devices -name \\*" + dev + " | grep -i pci"
        err = subprocess.call(cmd, shell=True)
        self.assertEqual(err, 0, "ERROR : Only NVMe PCI subsystem is supported")

    def load_config(self):
        """ Load Basic test configuration.
            - Args:
                - None
            - Returns:
                - None
        """
        with open(self.config_file) as data_file:
            config = json.load(data_file)
            self.ctrl = config['controller']
            self.ns1 = config['ns1']
            self.log_dir = config['log_dir']
            self.do_validate_pci_device = config.get('do_validate_pci_device', self.do_validate_pci_device)
            self.clear_log_dir = False

            if self.clear_log_dir is True:
                shutil.rmtree(self.log_dir, ignore_errors=True)

            if not os.path.exists(self.log_dir):
                os.makedirs(self.log_dir)

    def setup_log_dir(self, test_name):
        """ Set up the log directory for a testcase
            Args:
              - test_name : name of the testcase.
            Returns:
              - None
        """
        self.test_log_dir = self.log_dir + "/" + test_name
        if not os.path.exists(self.test_log_dir):
            os.makedirs(self.test_log_dir)
        sys.stdout = TestNVMeLogger(self.test_log_dir + "/" + "stdout.log")
        sys.stderr = TestNVMeLogger(self.test_log_dir + "/" + "stderr.log")

    def exec_cmd(self, cmd):
        """ Wrapper for executing a shell command and return the result. """
        proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE,
                                encoding='utf-8')
        return proc.wait()

    def nvme_reset_ctrl(self):
        """ Wrapper for nvme reset command.
            - Args:
                - None:
            - Returns:
                - None
        """
        nvme_reset_cmd = "nvme reset " + self.ctrl
        err = subprocess.call(nvme_reset_cmd,
                              shell=True,
                              stdout=subprocess.PIPE,
                              encoding='utf-8')
        self.assertEqual(err, 0, "ERROR : nvme reset failed")
        time.sleep(5)
        rescan_cmd = "echo 1 > /sys/bus/pci/rescan"
        proc = subprocess.Popen(rescan_cmd,
                                shell=True,
                                stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE,
                                encoding='utf-8')
        time.sleep(5)
        self.assertEqual(proc.wait(), 0, "ERROR : pci rescan failed")

    def get_ctrl_id(self):
        """ Wrapper for extracting the first controller id.
            - Args:
                - None
            - Returns:
                - controller id.
        """
        get_ctrl_id = f"nvme list-ctrl {self.ctrl} --output-format=json"
        proc = subprocess.Popen(get_ctrl_id,
                                shell=True,
                                stdout=subprocess.PIPE,
                                encoding='utf-8')
        err = proc.wait()
        self.assertEqual(err, 0, "ERROR : nvme list-ctrl failed")
        json_output = json.loads(proc.stdout.read())
        self.assertTrue(len(json_output['ctrl_list']) > 0,
                        "ERROR : nvme list-ctrl could not find ctrl")
        return str(json_output['ctrl_list'][0]['ctrl_id'])

    def get_ns_list(self):
        """ Wrapper for extracting the namespace list.
            - Args:
                - None
            - Returns:
                - List of the namespaces.
        """
        ns_list = []
        ns_list_cmd = "nvme list-ns " + self.ctrl
        proc = subprocess.Popen(ns_list_cmd,
                                shell=True,
                                stdout=subprocess.PIPE,
                                encoding='utf-8')
        self.assertEqual(proc.wait(), 0, "ERROR : nvme list namespace failed")
        for line in proc.stdout:
            ns_list.append(line.split('x')[-1])

        return ns_list

    def get_max_ns(self):
        """ Wrapper for extracting maximum number of namespaces supported.
            - Args:
                - None
            - Returns:
                - maximum number of namespaces supported.
        """
        pattern = re.compile("^nn[ ]+: [0-9]", re.IGNORECASE)
        max_ns = -1
        max_ns_cmd = "nvme id-ctrl " + self.ctrl
        proc = subprocess.Popen(max_ns_cmd,
                                shell=True,
                                stdout=subprocess.PIPE,
                                encoding='utf-8')
        err = proc.wait()
        self.assertEqual(err, 0, "ERROR : reading maximum namespace count failed")

        for line in proc.stdout:
            if pattern.match(line):
                max_ns = line.split(":")[1].strip()
                break
        print(max_ns)
        return int(max_ns)

    def get_lba_status_supported(self):
        """ Check if 'Get LBA Status' command is supported by the device
            - Args:
                - None
            - Returns:
                - True if 'Get LBA Status' command is supported, otherwise False
        """
        return int(self.get_id_ctrl_field_value("oacs"), 16) & (1 << 9)

    def get_lba_format_size(self):
        """ Wrapper for extracting lba format size of the given flbas
            - Args:
                - None
            - Returns:
                - lba format size as a tuple of (data_size, metadata_size) in bytes.
        """
        nvme_id_ns_cmd = f"nvme id-ns {self.ns1} --output-format=json"
        proc = subprocess.Popen(nvme_id_ns_cmd,
                                shell=True,
                                stdout=subprocess.PIPE,
                                encoding='utf-8')
        err = proc.wait()
        self.assertEqual(err, 0, "ERROR : reading id-ns")
        json_output = json.loads(proc.stdout.read())
        self.assertTrue(len(json_output['lbafs']) > self.flbas,
                        "Error : could not match the given flbas to an existing lbaf")
        lbaf_json = json_output['lbafs'][int(self.flbas)]
        ms_expo = int(lbaf_json['ms'])
        ds_expo = int(lbaf_json['ds'])
        ds = 0
        ms = 0
        if ds_expo > 0:
            ds = (1 << ds_expo)
        if ms_expo > 0:
            ms = (1 << ms_expo)
        return (ds, ms)

    def get_ncap(self):
        """ Wrapper for extracting capacity.
            - Args:
                - None
            - Returns:
                - maximum number of namespaces supported.
        """
        pattern = re.compile("^tnvmcap[ ]+: [0-9]", re.IGNORECASE)
        ncap = -1
        ncap_cmd = "nvme id-ctrl " + self.ctrl
        proc = subprocess.Popen(ncap_cmd,
                                shell=True,
                                stdout=subprocess.PIPE,
                                encoding='utf-8')
        err = proc.wait()
        self.assertEqual(err, 0, "ERROR : reading nvm capacity failed")

        for line in proc.stdout:
            if pattern.match(line):
                ncap = line.split(":")[1].strip()
                break
        print(ncap)
        return int(ncap)

    def get_id_ctrl_field_value(self, field):
        """ Wrapper for extracting id-ctrl field values
            - Args:
                - None
            - Returns:
                - Filed value of the given field
        """
        id_ctrl_cmd = f"nvme id-ctrl {self.ctrl} --output-format=json"
        proc = subprocess.Popen(id_ctrl_cmd,
                                shell=True,
                                stdout=subprocess.PIPE,
                                encoding='utf-8')
        err = proc.wait()
        self.assertEqual(err, 0, "ERROR : reading id-ctrl failed")
        json_output = json.loads(proc.stdout.read())
        self.assertTrue(field in json_output,
                        f"ERROR : reading field '{field}' failed")
        return str(json_output[field])

    def get_ocfs(self):
        """ Wrapper for extracting optional copy formats supported
            - Args:
                - None
            - Returns:
                - Optional Copy Formats Supported
        """
        return int(self.get_id_ctrl_field_value("ocfs"), 16)

    def get_format(self):
        """ Wrapper for extracting format.
            - Args:
                - None
            - Returns:
                - maximum format of namespace.
        """
        # defaulting to 4K
        nvm_format = 4096
        nvm_format_cmd = "nvme id-ns " + self.ctrl + " -n1"
        proc = subprocess.Popen(nvm_format_cmd,
                                shell=True,
                                stdout=subprocess.PIPE,
                                encoding='utf-8')
        err = proc.wait()
        self.assertEqual(err, 0, "ERROR : reading nvm capacity failed")

        for line in proc.stdout:
            if "in use" in line:
                nvm_format = 2 ** int(line.split(":")[3].split()[0])
        print(nvm_format)
        return int(nvm_format)

    def delete_all_ns(self):
        """ Wrapper for deleting all the namespaces.
            - Args:
                - None
            - Returns:
                - None
        """
        delete_ns_cmd = "nvme delete-ns " + self.ctrl + " -n 0xFFFFFFFF"
        self.assertEqual(self.exec_cmd(delete_ns_cmd), 0)
        list_ns_cmd = "nvme list-ns " + self.ctrl + " --all | wc -l"
        proc = subprocess.Popen(list_ns_cmd,
                                shell=True,
                                stdout=subprocess.PIPE,
                                encoding='utf-8')
        output = proc.stdout.read().strip()
        self.assertEqual(output, '0', "ERROR : deleting all namespace failed")

    def create_ns(self, nsze, ncap, flbas, dps):
        """ Wrapper for creating a namespace.
            - Args:
                - nsze : new namespace size.
                - ncap : new namespace capacity.
                - flbas : new namespace format.
                - dps : new namespace data protection information.
            - Returns:
                - return code of the nvme create namespace command.
        """
        create_ns_cmd = "nvme create-ns " + self.ctrl + " --nsze=" + \
                        str(nsze) + " --ncap=" + str(ncap) + \
                        " --flbas=" + str(flbas) + " --dps=" + str(dps)
        return self.exec_cmd(create_ns_cmd)

    def create_and_validate_ns(self, nsid, nsze, ncap, flbas, dps):
        """ Wrapper for creating and validating a namespace.
            - Args:
                - nsid : new namespace id.
                - nsze : new namespace size.
                - ncap : new namespace capacity.
                - flbas : new namespace format.
                - dps : new namespace data protection information.
            - Returns:
                - return 0 on success, error code on failure.
        """
        err = self.create_ns(nsze, ncap, flbas, dps)
        if err == 0:
            time.sleep(2)
            id_ns_cmd = "nvme id-ns " + self.ctrl + " -n " + str(nsid)
            err = subprocess.call(id_ns_cmd,
                                  shell=True,
                                  stdout=subprocess.PIPE,
                                  encoding='utf-8')
        return err

    def attach_ns(self, ctrl_id, ns_id):
        """ Wrapper for attaching the namespace.
            - Args:
                - ctrl_id : controller id to which namespace to be attached.
                - nsid : new namespace id.
            - Returns:
                - 0 on success, error code on failure.
        """
        attach_ns_cmd = "nvme attach-ns " + self.ctrl + \
                        " --namespace-id=" + str(ns_id) + \
                        " --controllers=" + ctrl_id
        err = subprocess.call(attach_ns_cmd,
                              shell=True,
                              stdout=subprocess.PIPE,
                              encoding='utf-8')
        time.sleep(5)
        if err == 0:
            # enumerate new namespace block device
            self.nvme_reset_ctrl()
            time.sleep(5)
            # check if new namespace block device exists
            err = 0 if stat.S_ISBLK(os.stat(self.ns1).st_mode) else 1
        return err

    def detach_ns(self, ctrl_id, nsid):
        """ Wrapper for detaching the namespace.
            - Args:
                - ctrl_id : controller id to which namespace to be attached.
                - nsid : new namespace id.
            - Returns:
                - 0 on success, error code on failure.
        """
        detach_ns_cmd = "nvme detach-ns " + self.ctrl + \
                        " --namespace-id=" + str(nsid) + \
                        " --controllers=" + ctrl_id
        return subprocess.call(detach_ns_cmd,
                               shell=True,
                               stdout=subprocess.PIPE,
                               encoding='utf-8')

    def delete_and_validate_ns(self, nsid):
        """ Wrapper for deleting and validating that namespace is deleted.
            - Args:
                - nsid : new namespace id.
            - Returns:
                - 0 on success, 1 on failure.
        """
        # delete the namespace
        delete_ns_cmd = "nvme delete-ns " + self.ctrl + " -n " + str(nsid)
        err = subprocess.call(delete_ns_cmd,
                              shell=True,
                              stdout=subprocess.PIPE,
                              encoding='utf-8')
        self.assertEqual(err, 0, "ERROR : delete namespace failed")
        return err

    def get_smart_log(self, nsid):
        """ Wrapper for nvme smart-log command.
            - Args:
                - nsid : namespace id to get smart log from.
            - Returns:
                - 0 on success, error code on failure.
        """
        smart_log_cmd = "nvme smart-log " + self.ctrl + " -n " + str(nsid)
        print(smart_log_cmd)
        proc = subprocess.Popen(smart_log_cmd,
                                shell=True,
                                stdout=subprocess.PIPE,
                                encoding='utf-8')
        err = proc.wait()
        self.assertEqual(err, 0, "ERROR : nvme smart log failed")
        smart_log_output = proc.communicate()[0]
        print(f"{smart_log_output}")
        return err

    def get_id_ctrl(self, vendor=False):
        """ Wrapper for nvme id-ctrl command.
            - Args:
              - None
            - Returns:
              - 0 on success, error code on failure.
        """
        if not vendor:
            id_ctrl_cmd = "nvme id-ctrl " + self.ctrl
        else:
            id_ctrl_cmd = "nvme id-ctrl --vendor-specific " + self.ctrl
        print(id_ctrl_cmd)
        proc = subprocess.Popen(id_ctrl_cmd,
                                shell=True,
                                stdout=subprocess.PIPE,
                                encoding='utf-8')
        err = proc.wait()
        self.assertEqual(err, 0, "ERROR : nvme id controller failed")
        return err

    def get_error_log(self):
        """ Wrapper for nvme error-log command.
            - Args:
                - None
            - Returns:
                - 0 on success, error code on failure.
        """
        pattern = re.compile(r"^ Entry\[[ ]*[0-9]+\]")
        error_log_cmd = "nvme error-log " + self.ctrl
        proc = subprocess.Popen(error_log_cmd,
                                shell=True,
                                stdout=subprocess.PIPE,
                                encoding='utf-8')
        err = proc.wait()
        self.assertEqual(err, 0, "ERROR : nvme error log failed")
        line = proc.stdout.readline()
        err_log_entry_count = int(line.split(" ")[5].strip().split(":")[1])
        entry_count = 0
        for line in proc.stdout:
            if pattern.match(line):
                entry_count += 1

        return 0 if err_log_entry_count == entry_count else 1

    def run_ns_io(self, nsid, lbads, count=10):
        """ Wrapper to run ios on namespace under test.
            - Args:
                - lbads : LBA Data size supported in power of 2 format.
            - Returns:
                - None
        """
        block_size = mmap.PAGESIZE if int(lbads) < 9 else 2 ** int(lbads)
        ns_path = self.ctrl + "n" + str(nsid)
        io_cmd = "dd if=" + ns_path + " of=/dev/null" + " bs=" + \
                 str(block_size) + " count=" + str(count) + " > /dev/null 2>&1"
        print(io_cmd)
        run_io = subprocess.Popen(io_cmd, shell=True, stdout=subprocess.PIPE,
                                  encoding='utf-8')
        run_io_result = run_io.communicate()[1]
        self.assertEqual(run_io_result, None)
        io_cmd = "dd if=/dev/zero of=" + ns_path + " bs=" + \
                 str(block_size) + " count=" + str(count) + " > /dev/null 2>&1"
        print(io_cmd)
        run_io = subprocess.Popen(io_cmd, shell=True, stdout=subprocess.PIPE,
                                  encoding='utf-8')
        run_io_result = run_io.communicate()[1]
        self.assertEqual(run_io_result, None)

    def supp_check_id_ctrl(self, key):
        """ Wrapper for support check.
            - Args:
                - key : search key.
            - Returns:
                - value for key requested.
        """
        id_ctrl = "nvme id-ctrl " + self.ctrl
        print("\n" + id_ctrl)
        proc = subprocess.Popen(id_ctrl,
                                shell=True,
                                stdout=subprocess.PIPE,
                                encoding='utf-8')
        err = proc.wait()
        self.assertEqual(err, 0, "ERROR : nvme Identify controller Data \
                         structure failed")
        for line in proc.stdout:
            if key in line:
                key = line.replace(",", "", 1)
        print(key)
        val = (key.split(':'))[1].strip()
        return int(val, 16)