#!/usr/bin/python3
# pylint: disable=invalid-name
# pylint: enable=invalid-name

"""Prometheus exporter for Foomuuri metrics."""

import argparse
import json
import pathlib
import re
import subprocess
import time
from collections import Counter

# pylint: disable=import-error
from prometheus_client import REGISTRY, start_http_server
from prometheus_client.core import CounterMetricFamily, GaugeMetricFamily
from prometheus_client.registry import Collector


class MonitorCollector(Collector):
    """Collect Foomuuri Monitor metrics.

    The statistics file is re-read on every scrape. It is expected to be a
    JSON object mapping a name to an object with ``type``, ``state`` and
    ``time`` keys. ``time`` is a list of ping round-trip times (presumably
    milliseconds, as they are divided by 1000 to report seconds), with
    ``None`` marking a lost packet.
    """

    # pylint: disable=too-few-public-methods

    def __init__(self, args):
        """Initialize class.

        Only ``args.statistics_file`` (a file name) is used.
        """
        super().__init__()
        self._stat_file = pathlib.Path(args.statistics_file)

    def _read_stats(self):
        """Return the parsed statistics object, or None if unavailable.

        A missing or unreadable file, invalid JSON, and JSON that is not an
        object all count as unavailable.
        """
        try:
            stats = json.loads(self._stat_file.read_text(encoding='utf-8'))
        except (OSError, ValueError):
            return None
        # Valid JSON such as "null" or "[]" has no .items(). Skip it the same
        # way as an unreadable file instead of failing the whole scrape.
        if not isinstance(stats, dict):
            return None
        return stats

    def collect(self):
        """Return Foomuuri Monitor metrics.

        Yields nothing when the statistics file is not available.
        """
        stats = self._read_stats()
        if stats is None:
            return

        # Target/group state
        g = GaugeMetricFamily(
            'foomuuri_monitor_up',
            'Target connectivity status.',
            labels=['type', 'name'],
        )
        for name, value in stats.items():
            g.add_metric([value['type'], name], value['state'])
        yield g

        # Target packet loss: fraction of ping samples that are None (lost)
        g = GaugeMetricFamily(
            'foomuuri_monitor_packet_loss_ratio',
            'Average ping packet loss.',
            labels=['type', 'name'],
        )
        for name, value in stats.items():
            times = value['time']
            if not times:
                continue  # no samples: nothing to report, avoid 0 division
            loss = sum(1 for item in times if item is None)
            g.add_metric([value['type'], name], loss / len(times))
        yield g

        # Target ping latency: mean of answered pings, converted to seconds
        g = GaugeMetricFamily(
            'foomuuri_monitor_ping_seconds',
            'Average network round trip time.',
            labels=['type', 'name'],
        )
        for name, value in stats.items():
            times = [item / 1000 for item in value['time'] if item is not None]
            if not times:
                continue  # every ping lost (or no samples): no latency
            g.add_metric([value['type'], name], sum(times) / len(times))
        yield g


class RulesetCollector(Collector):
    """Collect ruleset metrics.

    On every scrape this runs ``nft --json list table inet foomuuri``. It
    exports the element count of named sets and the byte/packet values of
    named counters, filtered by the include/exclude regular expressions.
    """

    # pylint: disable=too-few-public-methods

    def __init__(self, args):
        """Initialize class.

        Uses ``args.set_include``, ``args.set_exclude``,
        ``args.counter_include`` and ``args.counter_exclude`` (regular
        expression strings). An invalid expression raises ``re.error``.
        """
        super().__init__()
        self._set_include = re.compile(args.set_include)
        self._set_exclude = re.compile(args.set_exclude)
        self._counter_include = re.compile(args.counter_include)
        self._counter_exclude = re.compile(args.counter_exclude)

    @staticmethod
    def _selected(name, include, exclude):
        """Return True if name matches include and does not match exclude."""
        return bool(include.search(name)) and not exclude.search(name)

    @staticmethod
    def _read_ruleset():
        """Return the list under "nftables" in nft's JSON output, or None.

        None means nft could not be run or its output was not a JSON object.
        The exit status is not checked, so a failing nft is detected through
        its (typically empty) non-JSON output.
        """
        try:
            nftdata = json.loads(
                subprocess.run(
                    ['nft', '--json', 'list', 'table', 'inet', 'foomuuri'],
                    stdout=subprocess.PIPE,
                    check=False,
                    encoding='utf-8',
                ).stdout
            )
        except (OSError, ValueError):
            return None
        # Guard against valid JSON that is not an object (no .get()).
        if not isinstance(nftdata, dict):
            return None
        return nftdata.get('nftables', [])

    def collect(self):
        """Return ruleset metrics.

        Yields nothing when the ruleset could not be read.
        """
        items = self._read_ruleset()
        if items is None:
            return

        # Set size - merge IPv4 ("name_4") and IPv6 ("name_6") sets into a
        # single value. Only strip a real suffix: blindly dropping the last
        # two characters would mangle, and possibly merge, other set names.
        sets = Counter()
        for item in items:
            if 'set' not in item:
                continue
            name = item['set']['name']
            if name.endswith(('_4', '_6')):
                name = name[:-2]
            if self._selected(name, self._set_include, self._set_exclude):
                # "+=" also records empty sets with value 0
                sets[name] += len(item['set'].get('elem', []))

        g = GaugeMetricFamily(
            'foomuuri_set_elements',
            'Number of elements in set.',
            labels=['name'],
        )
        for name, value in sets.items():
            g.add_metric([name], value)
        yield g

        # Named counters
        counters = {}
        for item in items:
            if 'counter' not in item:
                continue
            counter = item['counter']
            name = counter['name']
            if self._selected(
                name, self._counter_include, self._counter_exclude
            ):
                counters[name] = {
                    'bytes': counter['bytes'],
                    'packets': counter['packets'],
                }

        g = CounterMetricFamily(
            'foomuuri_counter_bytes_total',
            'Counter bytes value.',
            labels=['name'],
        )
        for name, value in counters.items():
            g.add_metric([name], value['bytes'])
        yield g

        g = CounterMetricFamily(
            'foomuuri_counter_packets_total',
            'Counter packets value.',
            labels=['name'],
        )
        for name, value in counters.items():
            g.add_metric([name], value['packets'])
        yield g


def main():
    """Parse command line arguments, register collectors and serve metrics.

    Never returns normally. After start_http_server() returns, this thread
    only sleeps while the server answers scrapes in the background.
    """
    # Command line parser
    parser = argparse.ArgumentParser(
        description='Export Foomuuri statistics to Prometheus.'
    )
    parser.add_argument(
        '--address', default='::', help='listen address (default: ::)'
    )
    parser.add_argument(
        '--port',
        type=int,
        default=11041,
        help='listen port number (default: 11041)',
    )
    parser.add_argument(
        '--tls-certificate',
        metavar='FILENAME',
        help='TLS certificate file name',
    )
    parser.add_argument(
        '--tls-key', metavar='FILENAME', help='TLS key file name'
    )
    parser.add_argument(
        '--no-monitor-statistics',
        action='store_true',
        help='do not export Foomuuri Monitor statistics',
    )
    parser.add_argument(
        '--no-ruleset-statistics',
        action='store_true',
        help='do not export ruleset statistics',
    )
    parser.add_argument(
        '--statistics-file',
        metavar='FILENAME',
        default='/var/lib/foomuuri/monitor.statistics',
        help='Foomuuri Monitor statistics file name',
    )
    parser.add_argument(
        '--set-include',
        metavar='REGEXP',
        default='.',  # default all
        help='set names to be included to ruleset size statistics',
    )
    parser.add_argument(
        '--set-exclude',
        metavar='REGEXP',
        default='$^',  # default nothing
        help='set names to be excluded from ruleset size statistics',
    )
    parser.add_argument(
        '--counter-include',
        metavar='REGEXP',
        default='.',
        help='counter names to be included to ruleset traffic statistics',
    )
    parser.add_argument(
        '--counter-exclude',
        metavar='REGEXP',
        default='$^',
        help='counter names to be excluded from ruleset traffic statistics',
    )
    args = parser.parse_args()

    # prometheus_client wraps the socket in TLS only when both certfile and
    # keyfile are given. With just one of them it would silently serve
    # plain HTTP, so reject that combination up front.
    if bool(args.tls_certificate) != bool(args.tls_key):
        parser.error('--tls-certificate and --tls-key must be used together')

    # Register collectors
    if not args.no_monitor_statistics:
        REGISTRY.register(MonitorCollector(args))
    if not args.no_ruleset_statistics:
        REGISTRY.register(RulesetCollector(args))

    # Run exporter
    start_http_server(
        port=args.port,
        addr=args.address,
        certfile=args.tls_certificate,
        keyfile=args.tls_key,
    )
    # Keep the main thread alive; scrapes are served by the HTTP server.
    while True:
        time.sleep(1)


if __name__ == '__main__':
    # Start the exporter only when run as a script, not when imported.
    main()
