Move wg commands into config methods with validation

pull/1091/head
Super User 2026-01-30 17:56:30 -05:00
parent 1972b1e7a8
commit ef219e6a8e
5 changed files with 241 additions and 168 deletions

View File

@ -5,8 +5,7 @@ import subprocess
import uuid
from .Peer import Peer
from .Utilities import ValidateIPAddressesWithRange, ValidateDNSAddress, GenerateWireguardPublicKey, \
WgQuick, WgSetPeerAllowedIps
from .Utilities import ValidateIPAddressesWithRange, ValidateDNSAddress, GenerateWireguardPublicKey
class AmneziaWGPeer(Peer):
@ -59,23 +58,12 @@ class AmneziaWGPeer(Peer):
with open(uid, "w+") as f:
f.write(preshared_key)
psk_path = uid if pskExist else "/dev/null"
updateAllowedIp = WgSetPeerAllowedIps(
self.configuration.Protocol,
self.configuration.Name,
self.id,
allowed_ip,
psk_path
)
if pskExist: os.remove(uid)
if len(updateAllowedIp.decode().strip("\n")) != 0:
return False, "Update peer failed when updating Allowed IPs"
saveConfig = WgQuick(self.configuration.Protocol, "save", self.configuration.Name)
if f"wg showconf {self.configuration.Name}" not in saveConfig.decode().strip('\n'):
return False, "Update peer failed when saving the configuration"
with self.configuration.engine.begin() as conn:
previous = conn.execute(
self.configuration.peersTable.select().where(
self.configuration.peersTable.c.id == self.id
)
).mappings().fetchone()
conn.execute(
self.configuration.peersTable.update().values({
"name": name,
@ -85,11 +73,48 @@ class AmneziaWGPeer(Peer):
"mtu": mtu,
"keepalive": keepalive,
"preshared_key": preshared_key,
"advanced_security": advanced_security
"advanced_security": advanced_security,
"allowed_ip": allowed_ip.replace(" ", "")
}).where(
self.configuration.peersTable.c.id == self.id
)
)
try:
updateAllowedIp = self.configuration._wg_set_peer_allowed_ips(
self.id,
psk_path
)
if pskExist: os.remove(uid)
if len(updateAllowedIp.decode().strip("\n")) != 0:
raise subprocess.CalledProcessError(1, "wg set peer")
saveConfig = self.configuration._wg_quick_save()
if f"wg showconf {self.configuration.Name}" not in saveConfig.decode().strip('\n'):
raise subprocess.CalledProcessError(1, "wg-quick save")
except (subprocess.CalledProcessError, ValueError) as exc:
with self.configuration.engine.begin() as conn:
if previous:
conn.execute(
self.configuration.peersTable.update().values({
"name": previous["name"],
"private_key": previous["private_key"],
"DNS": previous["DNS"],
"endpoint_allowed_ip": previous["endpoint_allowed_ip"],
"mtu": previous["mtu"],
"keepalive": previous["keepalive"],
"preshared_key": previous["preshared_key"],
"advanced_security": previous["advanced_security"],
"allowed_ip": previous["allowed_ip"]
}).where(
self.configuration.peersTable.c.id == self.id
)
)
if pskExist and os.path.exists(uid):
os.remove(uid)
if isinstance(exc, subprocess.CalledProcessError):
return False, str(exc)
return False, str(exc)
self.configuration.getPeers()
return True, None
except subprocess.CalledProcessError as exc:

View File

@ -6,7 +6,7 @@ from flask import current_app
from .PeerJobs import PeerJobs
from .AmneziaWGPeer import AmneziaWGPeer
from .PeerShareLinks import PeerShareLinks
from .Utilities import RegexMatch, WgQuick, WgSetPeerAllowedIps
from .Utilities import RegexMatch
from .WireguardConfiguration import WireguardConfiguration
from .DashboardWebHooks import DashboardWebHooks
@ -293,16 +293,13 @@ class AmneziaWireguardConfiguration(WireguardConfiguration):
with open(uid, "w+") as f:
f.write(p['preshared_key'])
WgSetPeerAllowedIps(
self.Protocol,
self.Name,
self._wg_set_peer_allowed_ips(
p['id'],
p['allowed_ip'],
uid if presharedKeyExist else None
)
if presharedKeyExist:
os.remove(uid)
WgQuick(self.Protocol, "save", self.Name)
self._wg_quick_save()
self.getPeers()
for p in peers:
p = self.searchPeer(p['id'])

View File

@ -11,8 +11,7 @@ import jinja2
import sqlalchemy as db
from .PeerJob import PeerJob
from .PeerShareLink import PeerShareLink
from .Utilities import GenerateWireguardPublicKey, ValidateIPAddressesWithRange, ValidateDNSAddress, \
WgQuick, WgSetPeerAllowedIps
from .Utilities import GenerateWireguardPublicKey, ValidateIPAddressesWithRange, ValidateDNSAddress
class Peer:
@ -95,21 +94,12 @@ class Peer:
with open(uid, "w+") as f:
f.write(preshared_key)
psk_path = uid if pskExist else "/dev/null"
updateAllowedIp = WgSetPeerAllowedIps(
self.configuration.Protocol,
self.configuration.Name,
self.id,
allowed_ip,
psk_path
)
if pskExist: os.remove(uid)
if len(updateAllowedIp.decode().strip("\n")) != 0:
return False, "Update peer failed when updating Allowed IPs"
saveConfig = WgQuick(self.configuration.Protocol, "save", self.configuration.Name)
if f"wg showconf {self.configuration.Name}" not in saveConfig.decode().strip('\n'):
return False, "Update peer failed when saving the configuration"
with self.configuration.engine.begin() as conn:
previous = conn.execute(
self.configuration.peersTable.select().where(
self.configuration.peersTable.c.id == self.id
)
).mappings().fetchone()
conn.execute(
self.configuration.peersTable.update().values({
"name": name,
@ -118,11 +108,45 @@ class Peer:
"endpoint_allowed_ip": endpoint_allowed_ip,
"mtu": mtu,
"keepalive": keepalive,
"preshared_key": preshared_key
"preshared_key": preshared_key,
"allowed_ip": allowed_ip.replace(" ", "")
}).where(
self.configuration.peersTable.c.id == self.id
)
)
try:
updateAllowedIp = self.configuration._wg_set_peer_allowed_ips(
self.id,
psk_path
)
if pskExist: os.remove(uid)
if len(updateAllowedIp.decode().strip("\n")) != 0:
raise subprocess.CalledProcessError(1, "wg set peer")
saveConfig = self.configuration._wg_quick_save()
if f"wg showconf {self.configuration.Name}" not in saveConfig.decode().strip('\n'):
raise subprocess.CalledProcessError(1, "wg-quick save")
except (subprocess.CalledProcessError, ValueError) as exc:
with self.configuration.engine.begin() as conn:
if previous:
conn.execute(
self.configuration.peersTable.update().values({
"name": previous["name"],
"private_key": previous["private_key"],
"DNS": previous["DNS"],
"endpoint_allowed_ip": previous["endpoint_allowed_ip"],
"mtu": previous["mtu"],
"keepalive": previous["keepalive"],
"preshared_key": previous["preshared_key"],
"allowed_ip": previous["allowed_ip"]
}).where(
self.configuration.peersTable.c.id == self.id
)
)
if pskExist and os.path.exists(uid):
os.remove(uid)
if isinstance(exc, subprocess.CalledProcessError):
return False, str(exc)
return False, str(exc)
return True, None
except subprocess.CalledProcessError as exc:
return False, exc.output.decode("UTF-8").strip()

View File

@ -68,127 +68,30 @@ def ValidateEndpointAllowedIPs(IPs) -> tuple[bool, str] | tuple[bool, None]:
return False, str(e)
return True, None
_ALLOWED_PROTOCOLS = {
"wg": {
"exe": ("/usr/sbin/wg", "/usr/bin/wg"),
"quick": ("/usr/sbin/wg-quick", "/usr/bin/wg-quick"),
},
"awg": {
"exe": ("/usr/sbin/awg", "/usr/bin/awg"),
"quick": ("/usr/sbin/awg-quick", "/usr/bin/awg-quick"),
},
}
_ALLOWED_SUDO = ("/usr/sbin/sudo", "/usr/bin/sudo")
_IFACE_RE = re.compile(r"^[A-Za-z0-9_.-]{1,15}$")
_PEER_RE = re.compile(r"^[A-Za-z0-9+/=]{32,64}$")
_WG_EXE = ("/usr/sbin/wg", "/usr/bin/wg")
_WG_QUICK_EXE = ("/usr/sbin/wg-quick", "/usr/bin/wg-quick")
def _resolve_executable(protocol: str, quick: bool) -> str:
if protocol not in _ALLOWED_PROTOCOLS:
raise ValueError(f"Unsupported protocol: {protocol}")
key = "quick" if quick else "exe"
candidates = _ALLOWED_PROTOCOLS[protocol][key]
for path in candidates:
def _resolve_wg_exe() -> str:
for path in _WG_EXE:
if os.path.exists(path):
return path
fallback = shutil.which(f"{protocol}-quick" if quick else protocol)
fallback = shutil.which("wg")
if fallback:
fallback = os.path.realpath(fallback)
if fallback in candidates:
if fallback in _WG_EXE:
return fallback
raise FileNotFoundError(f"{protocol} binary not found in allowed paths")
def _resolve_sudo() -> str:
for path in _ALLOWED_SUDO:
if os.path.exists(path):
return path
fallback = shutil.which("sudo")
if fallback:
fallback = os.path.realpath(fallback)
if fallback in _ALLOWED_SUDO:
return fallback
raise FileNotFoundError("sudo not found in allowed paths")
def _validate_interface(name: str) -> str:
if not name or not _IFACE_RE.fullmatch(name):
raise ValueError(f"Invalid interface name: {name}")
return name
def _validate_peer_id(peer_id: str) -> str:
if not peer_id or not _PEER_RE.fullmatch(peer_id):
raise ValueError("Invalid peer public key")
return peer_id
def _normalize_allowed_ips(allowed_ips: str) -> str:
if allowed_ips is None:
raise ValueError("AllowedIPs is required")
cleaned = str(allowed_ips).replace(" ", "")
ok, err = ValidateEndpointAllowedIPs(cleaned)
if not ok:
raise ValueError(err or "Invalid AllowedIPs")
return cleaned
def _apply_sudo(cmd: list[str], require_root: bool) -> list[str]:
if require_root and os.geteuid() != 0:
sudo_path = _resolve_sudo()
return [sudo_path, "--non-interactive"] + cmd
return cmd
def WgShow(protocol: str, interface: str, field: str) -> bytes:
if field not in ("transfer", "endpoints", "latest-handshakes"):
raise ValueError(f"Unsupported show field: {field}")
exe = _resolve_executable(protocol, quick=False)
iface = _validate_interface(interface)
cmd = _apply_sudo([exe, "show", iface, field], require_root=True)
return subprocess.check_output(cmd, stderr=subprocess.STDOUT)
def WgQuick(protocol: str, action: str, interface: str) -> bytes:
if action not in ("up", "down", "save"):
raise ValueError(f"Unsupported wg-quick action: {action}")
exe = _resolve_executable(protocol, quick=True)
iface = _validate_interface(interface)
cmd = _apply_sudo([exe, action, iface], require_root=True)
return subprocess.check_output(cmd, stderr=subprocess.STDOUT)
def WgSetPeerAllowedIps(protocol: str, interface: str, peer_id: str,
allowed_ips: str, preshared_key_path: str | None = None) -> bytes:
exe = _resolve_executable(protocol, quick=False)
iface = _validate_interface(interface)
peer = _validate_peer_id(peer_id)
allowed = _normalize_allowed_ips(allowed_ips)
cmd = [exe, "set", iface, "peer", peer, "allowed-ips", allowed]
if preshared_key_path:
cmd += ["preshared-key", preshared_key_path]
cmd = _apply_sudo(cmd, require_root=True)
return subprocess.check_output(cmd, stderr=subprocess.STDOUT)
def WgPeerRemove(protocol: str, interface: str, peer_id: str) -> bytes:
exe = _resolve_executable(protocol, quick=False)
iface = _validate_interface(interface)
peer = _validate_peer_id(peer_id)
cmd = _apply_sudo([exe, "set", iface, "peer", peer, "remove"], require_root=True)
return subprocess.check_output(cmd, stderr=subprocess.STDOUT)
raise FileNotFoundError("wg binary not found in allowed paths")
def WgPubkey(private_key: bytes) -> bytes:
exe = _resolve_executable("wg", quick=False)
cmd = [exe, "pubkey"]
return subprocess.check_output(cmd, input=private_key, stderr=subprocess.STDOUT)
exe = _resolve_wg_exe()
return subprocess.check_output([exe, "pubkey"], input=private_key, stderr=subprocess.STDOUT)
def WgGenkey() -> bytes:
exe = _resolve_executable("wg", quick=False)
cmd = [exe, "genkey"]
return subprocess.check_output(cmd, stderr=subprocess.STDOUT)
exe = _resolve_wg_exe()
return subprocess.check_output([exe, "genkey"], stderr=subprocess.STDOUT)
def GenerateWireguardPublicKey(privateKey: str) -> tuple[bool, str] | tuple[bool, None]:
try:

View File

@ -16,12 +16,26 @@ from .Peer import Peer
from .PeerJobs import PeerJobs
from .PeerShareLinks import PeerShareLinks
from .Utilities import StringToBoolean, GenerateWireguardPublicKey, RegexMatch, ValidateDNSAddress, \
ValidateEndpointAllowedIPs, WgShow, WgQuick, WgSetPeerAllowedIps, WgPeerRemove
ValidateEndpointAllowedIPs
from .WireguardConfigurationInfo import WireguardConfigurationInfo, PeerGroupsClass
from .DashboardWebHooks import DashboardWebHooks
class WireguardConfiguration:
_WG_BINARIES = {
"wg": {
"exe": ("/usr/sbin/wg", "/usr/bin/wg"),
"quick": ("/usr/sbin/wg-quick", "/usr/bin/wg-quick"),
},
"awg": {
"exe": ("/usr/sbin/awg", "/usr/bin/awg"),
"quick": ("/usr/sbin/awg-quick", "/usr/bin/awg-quick"),
},
}
_SUDO_BINARIES = ("/usr/sbin/sudo", "/usr/bin/sudo")
_IFACE_RE = re.compile(r"^[A-Za-z0-9_.-]{1,15}$")
_PEER_RE = re.compile(r"^[A-Za-z0-9+/=]{32,64}$")
class InvalidConfigurationFileException(Exception):
def __init__(self, m):
self.message = m
@ -142,6 +156,122 @@ class WireguardConfiguration:
self.addAutostart()
def _resolve_executable(self, quick: bool) -> str:
if self.Protocol not in self._WG_BINARIES:
raise ValueError(f"Unsupported protocol: {self.Protocol}")
key = "quick" if quick else "exe"
candidates = self._WG_BINARIES[self.Protocol][key]
for path in candidates:
if os.path.exists(path):
return path
fallback = shutil.which(f"{self.Protocol}-quick" if quick else self.Protocol)
if fallback:
fallback = os.path.realpath(fallback)
if fallback in candidates:
return fallback
raise FileNotFoundError(f"{self.Protocol} binary not found in allowed paths")
def _resolve_sudo(self) -> str:
for path in self._SUDO_BINARIES:
if os.path.exists(path):
return path
fallback = shutil.which("sudo")
if fallback:
fallback = os.path.realpath(fallback)
if fallback in self._SUDO_BINARIES:
return fallback
raise FileNotFoundError("sudo not found in allowed paths")
def _apply_sudo(self, cmd: list[str]) -> list[str]:
if os.geteuid() != 0:
sudo_path = self._resolve_sudo()
return [sudo_path, "--non-interactive"] + cmd
return cmd
def _validate_interface(self) -> str:
if not self.Name or not self._IFACE_RE.fullmatch(self.Name):
raise ValueError(f"Invalid interface name: {self.Name}")
return self.Name
def _validate_peer_id(self, peer_id: str) -> str:
if not peer_id or not self._PEER_RE.fullmatch(peer_id):
raise ValueError("Invalid peer public key")
return peer_id
def _normalize_allowed_ips(self, allowed_ips: str) -> str:
cleaned = str(allowed_ips or "").replace(" ", "")
ok, err = ValidateEndpointAllowedIPs(cleaned)
if not ok:
raise ValueError(err or "Invalid AllowedIPs")
return cleaned
def _wg_show_transfer(self) -> bytes:
exe = self._resolve_executable(quick=False)
iface = self._validate_interface()
cmd = self._apply_sudo([exe, "show", iface, "transfer"])
return subprocess.check_output(cmd, stderr=subprocess.STDOUT)
def _wg_show_endpoints(self) -> bytes:
exe = self._resolve_executable(quick=False)
iface = self._validate_interface()
cmd = self._apply_sudo([exe, "show", iface, "endpoints"])
return subprocess.check_output(cmd, stderr=subprocess.STDOUT)
def _wg_show_latest_handshakes(self) -> bytes:
exe = self._resolve_executable(quick=False)
iface = self._validate_interface()
cmd = self._apply_sudo([exe, "show", iface, "latest-handshakes"])
return subprocess.check_output(cmd, stderr=subprocess.STDOUT)
def _wg_quick_up(self) -> bytes:
exe = self._resolve_executable(quick=True)
iface = self._validate_interface()
cmd = self._apply_sudo([exe, "up", iface])
return subprocess.check_output(cmd, stderr=subprocess.STDOUT)
def _wg_quick_down(self) -> bytes:
exe = self._resolve_executable(quick=True)
iface = self._validate_interface()
cmd = self._apply_sudo([exe, "down", iface])
return subprocess.check_output(cmd, stderr=subprocess.STDOUT)
def _wg_quick_save(self) -> bytes:
exe = self._resolve_executable(quick=True)
iface = self._validate_interface()
cmd = self._apply_sudo([exe, "save", iface])
return subprocess.check_output(cmd, stderr=subprocess.STDOUT)
def _wg_set_peer_allowed_ips(self, peer_id: str,
preshared_key_path: str | None = None) -> bytes:
exe = self._resolve_executable(quick=False)
iface = self._validate_interface()
with self.engine.connect() as conn:
row = conn.execute(
self.peersTable.select().where(self.peersTable.c.id == peer_id)
).mappings().fetchone()
if not row:
raise ValueError("Peer not found")
peer = self._validate_peer_id(row["id"])
allowed = self._normalize_allowed_ips(row.get("allowed_ip"))
cmd = [exe, "set", iface, "peer", peer, "allowed-ips", allowed]
if preshared_key_path:
cmd += ["preshared-key", preshared_key_path]
cmd = self._apply_sudo(cmd)
return subprocess.check_output(cmd, stderr=subprocess.STDOUT)
def _wg_peer_remove(self, peer_id: str) -> bytes:
exe = self._resolve_executable(quick=False)
iface = self._validate_interface()
with self.engine.connect() as conn:
row = conn.execute(
self.peersTable.select().where(self.peersTable.c.id == peer_id)
).mappings().fetchone()
if not row:
raise ValueError("Peer not found")
peer = self._validate_peer_id(row["id"])
cmd = self._apply_sudo([exe, "set", iface, "peer", peer, "remove"])
return subprocess.check_output(cmd, stderr=subprocess.STDOUT)
def __getProtocolPath(self) -> str:
_, path = self.DashboardConfig.GetConfig("Server", "wg_conf_path") if self.Protocol == "wg" \
else self.DashboardConfig.GetConfig("Server", "awg_conf_path")
@ -560,16 +690,13 @@ class WireguardConfiguration:
with open(uid, "w+") as f:
f.write(p['preshared_key'])
WgSetPeerAllowedIps(
self.Protocol,
self.Name,
self._wg_set_peer_allowed_ips(
p['id'],
p['allowed_ip'],
uid if presharedKeyExist else None
)
if presharedKeyExist:
os.remove(uid)
WgQuick(self.Protocol, "save", self.Name)
self._wg_quick_save()
self.getPeers()
for p in peers:
p = self.searchPeer(p['id'])
@ -619,11 +746,8 @@ class WireguardConfiguration:
with open(uid, "w+") as f:
f.write(restrictedPeer['preshared_key'])
WgSetPeerAllowedIps(
self.Protocol,
self.Name,
self._wg_set_peer_allowed_ips(
restrictedPeer['id'],
restrictedPeer['allowed_ip'],
uid if presharedKeyExist else None
)
if presharedKeyExist: os.remove(uid)
@ -645,7 +769,7 @@ class WireguardConfiguration:
found, pf = self.searchPeer(p)
if found:
try:
WgPeerRemove(self.Protocol, self.Name, pf.id)
self._wg_peer_remove(pf.id)
conn.execute(
self.peersRestrictedTable.insert().from_select(
[c.name for c in self.peersTable.columns],
@ -696,7 +820,7 @@ class WireguardConfiguration:
AllPeerShareLinks.updateLinkExpireDate(shareLink.ShareID, datetime.now())
if found:
try:
WgPeerRemove(self.Protocol, self.Name, pf.id)
self._wg_peer_remove(pf.id)
conn.execute(
self.peersTable.delete().where(
self.peersTable.columns.id == pf.id
@ -726,7 +850,7 @@ class WireguardConfiguration:
def __wgSave(self) -> tuple[bool, str] | tuple[bool, None]:
try:
WgQuick(self.Protocol, "save", self.Name)
self._wg_quick_save()
return True, None
except subprocess.CalledProcessError as e:
return False, str(e)
@ -735,7 +859,7 @@ class WireguardConfiguration:
if not self.getStatus():
self.toggleConfiguration()
try:
latestHandshake = WgShow(self.Protocol, self.Name, "latest-handshakes")
latestHandshake = self._wg_show_latest_handshakes()
except subprocess.CalledProcessError:
return "stopped"
latestHandshake = latestHandshake.decode("UTF-8").split()
@ -774,7 +898,7 @@ class WireguardConfiguration:
if not self.getStatus():
self.toggleConfiguration()
# try:
data_usage = WgShow(self.Protocol, self.Name, "transfer")
data_usage = self._wg_show_transfer()
data_usage = data_usage.decode("UTF-8").split("\n")
data_usage = [p.split("\t") for p in data_usage]
@ -830,7 +954,7 @@ class WireguardConfiguration:
if not self.getStatus():
self.toggleConfiguration()
try:
data_usage = WgShow(self.Protocol, self.Name, "endpoints")
data_usage = self._wg_show_endpoints()
except subprocess.CalledProcessError:
return "stopped"
data_usage = data_usage.decode("UTF-8").split()
@ -850,13 +974,13 @@ class WireguardConfiguration:
self.getStatus()
if self.Status:
try:
check = WgQuick(self.Protocol, "down", self.Name)
check = self._wg_quick_down()
self.removeAutostart()
except subprocess.CalledProcessError as exc:
return False, str(exc.output.strip().decode("utf-8"))
else:
try:
check = WgQuick(self.Protocol, "up", self.Name)
check = self._wg_quick_up()
self.addAutostart()
except subprocess.CalledProcessError as exc:
return False, str(exc.output.strip().decode("utf-8"))