#!/bin/bash

# During the PAM session phase, remove the successfully SSH-authenticated
# host's source address(-prefix) from the configured rate-limiting nft(8) set.

# This script is part of an article originally published at
# https://johannes.truschnigg.info/writing/2025-02-simple_effective_ssh_ratelimiting_pam_nftables/

# A copy of it is to be installed like so:
# -r-xr----- 1 root staff 1.5K Feb 13 08:00 /usr/local/lib/pam_clear_nft_ratelimits
# 
# ... and integrated into your OpenSSH sshd PAM stack like so:
# session    optional     pam_exec.so /usr/local/lib/pam_clear_nft_ratelimits

# Copyright (C) 2025 Johannes Truschnigg  <johannes@truschnigg.info>
# Licensed under the terms of the GNU GPLv3 or later


# Configuration Section BEGIN
# These values have to correspond to your nftables ruleset choices!
# Each *_breadcrumbs_* value is the "<family> <table> <set>" triple naming the
# rate-limiting nft set; it is spliced verbatim into `nft delete element ...`.
nft_breadcrumbs_ipv4='inet firewall ssh_ratelimit_v4'
nft_breadcrumbs_ipv6='inet firewall ssh_ratelimit_v6'
# Set to exactly `true` if the IPv6 set holds /64 prefixes (the address is then
# masked with ffff:ffff:ffff:ffff:: before deletion); any other value deletes
# the exact address.
nft_ip6_use_prefix=true
# Configuration Section END

umask 0077
export TZ=UTC
export LC_ALL=C
# This runs as root and exec's `nft` by bare name. Do not trust whatever PATH
# arrives through the PAM environment (which pam_exec passes to the child and
# which earlier session modules may have populated); pin it to the standard
# system locations. Adjust if your nft binary lives elsewhere.
export PATH='/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin'
: DEBUG settings processed

# pam_exec may invoke us for other PAM phases too; only session opening
# (i.e. after successful authentication) should clear the rate limit.
if ! [[ ${PAM_TYPE} = 'open_session' ]]
then
    : Not executing in PAM session phase - exiting early
    exit 0
fi
: DEBUG PAM check passed

# Fail early, before any parsing, if the nft(8) binary is unavailable.
if ! type -p nft >/dev/null
then
    printf 'FATAL: Failed to find `nft` executable in PATH: %q\n' "${PATH}" >&2
    exit 1
fi
: DEBUG nft found in PATH

# The peer address is taken from SSH_CONNECTION further down; without it
# there is nothing to do.
if ! [[ ${SSH_CONNECTION} ]]
then
    printf 'FATAL: Failed to find SSH_CONNECTION in environment\n' >&2
    exit 1
fi
: DEBUG sshd environment check passed


# Try to validate an IPv4 or IPv6 address as communicated by OpenSSH sshd's
# SSH_CONNECTION env variable.
#
# Arguments: $1 - candidate address (compared case-insensitively)
# Outputs:   writes exactly one of 'ipv4', 'ipv6' or 'invalid' to stdout,
#            without a trailing newline
# Returns:   0 for 'ipv4'/'ipv6', 1 for 'invalid'
#
# This is a syntactic check only. Its job is to keep anything but a plain
# numeric address out of the nft(8) command line built from the result.
# Dotted IPv4-in-IPv6 notation (e.g. ::ffff:192.0.2.1) is not accepted.
parse_addr() {
    : DEBUG parsing input address: "${1}"
    local addr_in
    local addr_part
    local -a addr_arr
    local -i groups=0

    addr_in="${1,,}"

    # Reject empty input and any character that cannot appear in a numeric
    # address. This also catches embedded newlines, which `read` below would
    # otherwise silently cut off, letting trailing junk through unchecked.
    if [[ -z ${addr_in} || ${addr_in} = *[!0-9a-f.:]* ]]
    then
        printf 'invalid'
        return 1
    fi

    case "${addr_in}" in
    *.*.*.*)
        : DEBUG assuming IPv4 quad-dotted
        # `read` drops a trailing empty field, so "1.2.3.4." would otherwise
        # still yield four parts.
        if [[ ${addr_in} = *. ]]
        then
            printf 'invalid'
            return 1
        fi
        IFS='.' read -r -a addr_arr <<<"${addr_in}"
        if [[ ${#addr_arr[@]} -ne 4 ]]
        then
            printf 'invalid'
            return 1
        fi
        for addr_part in "${addr_arr[@]}"
        do
            case "${addr_part}" in
                # Decimal 0-255 without leading zeros, matched textually so
                # no arithmetic (and hence no octal interpretation of e.g.
                # "08") takes place.
                [0-9]|[1-9][0-9]|1[0-9][0-9]|2[0-4][0-9]|25[0-5])
                    : valid part: "${addr_part}"
                ;;
                *)
                    printf 'invalid'
                    return 1
                ;;
            esac
        done
        printf 'ipv4'
        return 0
    ;;
    *:*:*)
        : DEBUG assuming IPv6
        # At most one '::' and never ':::'.
        if [[ ${addr_in} = *:::* || ${addr_in} = *::*::* ]]
        then
            printf 'invalid'
            return 1
        fi
        # A lone leading ':' would produce an empty group outside of '::',
        # and `read` would silently discard a lone trailing ':'.
        if [[ ${addr_in} = :[!:]* || ${addr_in} = *[!:]: ]]
        then
            printf 'invalid'
            return 1
        fi
        IFS=':' read -r -a addr_arr <<<"${addr_in}"
        for addr_part in "${addr_arr[@]}"
        do
            case "${addr_part}" in
                '')
                    # After the checks above, empty fields only come from '::'.
                    :
                ;;
                [0-9a-f]|[0-9a-f][0-9a-f]|[0-9a-f][0-9a-f][0-9a-f]|[0-9a-f][0-9a-f][0-9a-f][0-9a-f])
                    : valid part: "${addr_part}"
                    (( groups += 1 ))
                ;;
                *)
                    printf 'invalid'
                    return 1
                ;;
            esac
        done
        # Without '::' all eight groups must be written out. '::' stands for
        # at least one zero group, leaving room for at most seven explicit ones.
        if [[ ${addr_in} = *::* ]]
        then
            if (( groups > 7 ))
            then
                printf 'invalid'
                return 1
            fi
        elif (( groups != 8 ))
        then
            printf 'invalid'
            return 1
        fi
        printf 'ipv6'
        return 0
    ;;
    *)
        printf 'invalid'
        return 1
    ;;
    esac
}


# The client address is the first of the four space-separated fields of
# SSH_CONNECTION (client address, client port, server address, server port).
raddr="${SSH_CONNECTION%% *}"
# A link-local IPv6 peer may be reported with a zone index appended
# (e.g. fe80::1%eth0). The zone is not part of the address itself, and the
# '%' would make parse_addr reject it, so strip everything from '%' on.
raddr="${raddr%%\%*}"
nft_arg=''
set -u
: DEBUG Processing input...

# parse_addr prints 'ipv4', 'ipv6' or 'invalid'.
addr_fam=$(parse_addr "${raddr}")
: DEBUG finished parsing address
if [[ ${addr_fam} = ipv4 ]]
then
    : DEBUG IPv4 address determined
    printf -v nft_arg 'delete element %s { %s }' "${nft_breadcrumbs_ipv4}" "${raddr}"
elif [[ ${addr_fam} = ipv6 ]]
then
    if [[ ${nft_ip6_use_prefix} = true ]]
    then
        : DEBUG IPv6 address determined, use_prefix in effect
        # Mask down to the /64 prefix (first four hextets) before deleting.
        printf -v nft_arg 'delete element %s { %s & ffff:ffff:ffff:ffff:: }' "${nft_breadcrumbs_ipv6}" "${raddr}"
    else
        : DEBUG IPv6 address determined
        printf -v nft_arg 'delete element %s { %s }' "${nft_breadcrumbs_ipv6}" "${raddr}"
    fi
else
    # 'invalid' or anything unexpected: never hand nft an empty or
    # unvalidated argument. Report on stderr like the other fatal errors.
    printf 'Failed to parse as IP address: %s\n' "${raddr}" >&2
    exit 1
fi

: DEBUG about to exec nft...
# exec replaces this shell, so nft's exit status becomes the script's.
exec nft "${nft_arg}"
