diff --git a/src/modules/AmneziaWGPeer.py b/src/modules/AmneziaWGPeer.py index 70e5bbdb..203db890 100644 --- a/src/modules/AmneziaWGPeer.py +++ b/src/modules/AmneziaWGPeer.py @@ -5,8 +5,7 @@ import subprocess import uuid from .Peer import Peer -from .Utilities import ValidateIPAddressesWithRange, ValidateDNSAddress, GenerateWireguardPublicKey, \ - WgQuick, WgSetPeerAllowedIps +from .Utilities import ValidateIPAddressesWithRange, ValidateDNSAddress, GenerateWireguardPublicKey class AmneziaWGPeer(Peer): @@ -59,23 +58,12 @@ class AmneziaWGPeer(Peer): with open(uid, "w+") as f: f.write(preshared_key) psk_path = uid if pskExist else "/dev/null" - updateAllowedIp = WgSetPeerAllowedIps( - self.configuration.Protocol, - self.configuration.Name, - self.id, - allowed_ip, - psk_path - ) - - if pskExist: os.remove(uid) - - if len(updateAllowedIp.decode().strip("\n")) != 0: - return False, "Update peer failed when updating Allowed IPs" - saveConfig = WgQuick(self.configuration.Protocol, "save", self.configuration.Name) - if f"wg showconf {self.configuration.Name}" not in saveConfig.decode().strip('\n'): - return False, "Update peer failed when saving the configuration" - with self.configuration.engine.begin() as conn: + previous = conn.execute( + self.configuration.peersTable.select().where( + self.configuration.peersTable.c.id == self.id + ) + ).mappings().fetchone() conn.execute( self.configuration.peersTable.update().values({ "name": name, @@ -85,11 +73,48 @@ class AmneziaWGPeer(Peer): "mtu": mtu, "keepalive": keepalive, "preshared_key": preshared_key, - "advanced_security": advanced_security + "advanced_security": advanced_security, + "allowed_ip": allowed_ip.replace(" ", "") }).where( self.configuration.peersTable.c.id == self.id ) ) + try: + updateAllowedIp = self.configuration._wg_set_peer_allowed_ips( + self.id, + psk_path + ) + + if pskExist: os.remove(uid) + + if len(updateAllowedIp.decode().strip("\n")) != 0: + raise subprocess.CalledProcessError(1, "wg set peer") + saveConfig = self.configuration._wg_quick_save() + if f"wg showconf {self.configuration.Name}" not in saveConfig.decode().strip('\n'): + raise subprocess.CalledProcessError(1, "wg-quick save") + except (subprocess.CalledProcessError, ValueError) as exc: + with self.configuration.engine.begin() as conn: + if previous: + conn.execute( + self.configuration.peersTable.update().values({ + "name": previous["name"], + "private_key": previous["private_key"], + "DNS": previous["DNS"], + "endpoint_allowed_ip": previous["endpoint_allowed_ip"], + "mtu": previous["mtu"], + "keepalive": previous["keepalive"], + "preshared_key": previous["preshared_key"], + "advanced_security": previous["advanced_security"], + "allowed_ip": previous["allowed_ip"] + }).where( + self.configuration.peersTable.c.id == self.id + ) + ) + if pskExist and os.path.exists(uid): + os.remove(uid) + if isinstance(exc, subprocess.CalledProcessError): + return False, str(exc) + return False, str(exc) self.configuration.getPeers() return True, None except subprocess.CalledProcessError as exc: diff --git a/src/modules/AmneziaWireguardConfiguration.py b/src/modules/AmneziaWireguardConfiguration.py index 56b9a411..839fe503 100644 --- a/src/modules/AmneziaWireguardConfiguration.py +++ b/src/modules/AmneziaWireguardConfiguration.py @@ -6,7 +6,7 @@ from flask import current_app from .PeerJobs import PeerJobs from .AmneziaWGPeer import AmneziaWGPeer from .PeerShareLinks import PeerShareLinks -from .Utilities import RegexMatch, WgQuick, WgSetPeerAllowedIps +from .Utilities import RegexMatch from .WireguardConfiguration import WireguardConfiguration from .DashboardWebHooks import DashboardWebHooks @@ -293,16 +293,13 @@ class AmneziaWireguardConfiguration(WireguardConfiguration): with open(uid, "w+") as f: f.write(p['preshared_key']) - WgSetPeerAllowedIps( - self.Protocol, - self.Name, + self._wg_set_peer_allowed_ips( p['id'], - p['allowed_ip'], uid if presharedKeyExist else None ) if presharedKeyExist: os.remove(uid) - WgQuick(self.Protocol, "save", self.Name) + self._wg_quick_save() self.getPeers() for p in peers: p = self.searchPeer(p['id']) diff --git a/src/modules/Peer.py b/src/modules/Peer.py index f37e8a7c..87b6f058 100644 --- a/src/modules/Peer.py +++ b/src/modules/Peer.py @@ -11,8 +11,7 @@ import jinja2 import sqlalchemy as db from .PeerJob import PeerJob from .PeerShareLink import PeerShareLink -from .Utilities import GenerateWireguardPublicKey, ValidateIPAddressesWithRange, ValidateDNSAddress, \ - WgQuick, WgSetPeerAllowedIps +from .Utilities import GenerateWireguardPublicKey, ValidateIPAddressesWithRange, ValidateDNSAddress class Peer: @@ -95,21 +94,12 @@ class Peer: with open(uid, "w+") as f: f.write(preshared_key) psk_path = uid if pskExist else "/dev/null" - updateAllowedIp = WgSetPeerAllowedIps( - self.configuration.Protocol, - self.configuration.Name, - self.id, - allowed_ip, - psk_path - ) - - if pskExist: os.remove(uid) - if len(updateAllowedIp.decode().strip("\n")) != 0: - return False, "Update peer failed when updating Allowed IPs" - saveConfig = WgQuick(self.configuration.Protocol, "save", self.configuration.Name) - if f"wg showconf {self.configuration.Name}" not in saveConfig.decode().strip('\n'): - return False, "Update peer failed when saving the configuration" with self.configuration.engine.begin() as conn: + previous = conn.execute( + self.configuration.peersTable.select().where( + self.configuration.peersTable.c.id == self.id + ) + ).mappings().fetchone() conn.execute( self.configuration.peersTable.update().values({ "name": name, @@ -118,11 +108,45 @@ class Peer: "endpoint_allowed_ip": endpoint_allowed_ip, "mtu": mtu, "keepalive": keepalive, - "preshared_key": preshared_key + "preshared_key": preshared_key, + "allowed_ip": allowed_ip.replace(" ", "") }).where( self.configuration.peersTable.c.id == self.id ) ) + try: + updateAllowedIp = self.configuration._wg_set_peer_allowed_ips( + self.id, + psk_path + ) + if pskExist: os.remove(uid) + if len(updateAllowedIp.decode().strip("\n")) != 0: + raise subprocess.CalledProcessError(1, "wg set peer") + saveConfig = self.configuration._wg_quick_save() + if f"wg showconf {self.configuration.Name}" not in saveConfig.decode().strip('\n'): + raise subprocess.CalledProcessError(1, "wg-quick save") + except (subprocess.CalledProcessError, ValueError) as exc: + with self.configuration.engine.begin() as conn: + if previous: + conn.execute( + self.configuration.peersTable.update().values({ + "name": previous["name"], + "private_key": previous["private_key"], + "DNS": previous["DNS"], + "endpoint_allowed_ip": previous["endpoint_allowed_ip"], + "mtu": previous["mtu"], + "keepalive": previous["keepalive"], + "preshared_key": previous["preshared_key"], + "allowed_ip": previous["allowed_ip"] + }).where( + self.configuration.peersTable.c.id == self.id + ) + ) + if pskExist and os.path.exists(uid): + os.remove(uid) + if isinstance(exc, subprocess.CalledProcessError): + return False, str(exc) + return False, str(exc) return True, None except subprocess.CalledProcessError as exc: return False, exc.output.decode("UTF-8").strip() diff --git a/src/modules/Utilities.py b/src/modules/Utilities.py index 60bf2c10..5c4ccaa7 100644 --- a/src/modules/Utilities.py +++ b/src/modules/Utilities.py @@ -68,127 +68,30 @@ def ValidateEndpointAllowedIPs(IPs) -> tuple[bool, str] | tuple[bool, None]: return False, str(e) return True, None -_ALLOWED_PROTOCOLS = { - "wg": { - "exe": ("/usr/sbin/wg", "/usr/bin/wg"), - "quick": ("/usr/sbin/wg-quick", "/usr/bin/wg-quick"), - }, - "awg": { - "exe": ("/usr/sbin/awg", "/usr/bin/awg"), - "quick": ("/usr/sbin/awg-quick", "/usr/bin/awg-quick"), - }, -} -_ALLOWED_SUDO = ("/usr/sbin/sudo", "/usr/bin/sudo") -_IFACE_RE = re.compile(r"^[A-Za-z0-9_.-]{1,15}$") -_PEER_RE = re.compile(r"^[A-Za-z0-9+/=]{32,64}$") +_WG_EXE = ("/usr/sbin/wg", "/usr/bin/wg") +_WG_QUICK_EXE = ("/usr/sbin/wg-quick", "/usr/bin/wg-quick") -def _resolve_executable(protocol: str, quick: bool) -> str: - if protocol not in _ALLOWED_PROTOCOLS: - raise ValueError(f"Unsupported protocol: {protocol}") - key = "quick" if quick else "exe" - candidates = _ALLOWED_PROTOCOLS[protocol][key] - for path in candidates: +def _resolve_wg_exe() -> str: + for path in _WG_EXE: if os.path.exists(path): return path - fallback = shutil.which(f"{protocol}-quick" if quick else protocol) + fallback = shutil.which("wg") if fallback: fallback = os.path.realpath(fallback) - if fallback in candidates: + if fallback in _WG_EXE: return fallback - raise FileNotFoundError(f"{protocol} binary not found in allowed paths") - - -def _resolve_sudo() -> str: - for path in _ALLOWED_SUDO: - if os.path.exists(path): - return path - fallback = shutil.which("sudo") - if fallback: - fallback = os.path.realpath(fallback) - if fallback in _ALLOWED_SUDO: - return fallback - raise FileNotFoundError("sudo not found in allowed paths") - - -def _validate_interface(name: str) -> str: - if not name or not _IFACE_RE.fullmatch(name): - raise ValueError(f"Invalid interface name: {name}") - return name - - -def _validate_peer_id(peer_id: str) -> str: - if not peer_id or not _PEER_RE.fullmatch(peer_id): - raise ValueError("Invalid peer public key") - return peer_id - - -def _normalize_allowed_ips(allowed_ips: str) -> str: - if allowed_ips is None: - raise ValueError("AllowedIPs is required") - cleaned = str(allowed_ips).replace(" ", "") - ok, err = ValidateEndpointAllowedIPs(cleaned) - if not ok: - raise ValueError(err or "Invalid AllowedIPs") - return cleaned - - -def _apply_sudo(cmd: list[str], require_root: bool) -> list[str]: - if require_root and os.geteuid() != 0: - sudo_path = _resolve_sudo() - return [sudo_path, "--non-interactive"] + cmd - return cmd - - -def WgShow(protocol: str, interface: str, field: str) -> bytes: - if field not in ("transfer", "endpoints", "latest-handshakes"): - raise ValueError(f"Unsupported show field: {field}") - exe = _resolve_executable(protocol, quick=False) - iface = _validate_interface(interface) - cmd = _apply_sudo([exe, "show", iface, field], require_root=True) - return subprocess.check_output(cmd, stderr=subprocess.STDOUT) - - -def WgQuick(protocol: str, action: str, interface: str) -> bytes: - if action not in ("up", "down", "save"): - raise ValueError(f"Unsupported wg-quick action: {action}") - exe = _resolve_executable(protocol, quick=True) - iface = _validate_interface(interface) - cmd = _apply_sudo([exe, action, iface], require_root=True) - return subprocess.check_output(cmd, stderr=subprocess.STDOUT) - - -def WgSetPeerAllowedIps(protocol: str, interface: str, peer_id: str, - allowed_ips: str, preshared_key_path: str | None = None) -> bytes: - exe = _resolve_executable(protocol, quick=False) - iface = _validate_interface(interface) - peer = _validate_peer_id(peer_id) - allowed = _normalize_allowed_ips(allowed_ips) - cmd = [exe, "set", iface, "peer", peer, "allowed-ips", allowed] - if preshared_key_path: - cmd += ["preshared-key", preshared_key_path] - cmd = _apply_sudo(cmd, require_root=True) - return subprocess.check_output(cmd, stderr=subprocess.STDOUT) - - -def WgPeerRemove(protocol: str, interface: str, peer_id: str) -> bytes: - exe = _resolve_executable(protocol, quick=False) - iface = _validate_interface(interface) - peer = _validate_peer_id(peer_id) - cmd = _apply_sudo([exe, "set", iface, "peer", peer, "remove"], require_root=True) - return subprocess.check_output(cmd, stderr=subprocess.STDOUT) + raise FileNotFoundError("wg binary not found in allowed paths") def WgPubkey(private_key: bytes) -> bytes: - exe = _resolve_executable("wg", quick=False) - cmd = [exe, "pubkey"] - return subprocess.check_output(cmd, input=private_key, stderr=subprocess.STDOUT) + exe = _resolve_wg_exe() + return subprocess.check_output([exe, "pubkey"], input=private_key, stderr=subprocess.STDOUT) def WgGenkey() -> bytes: - exe = _resolve_executable("wg", quick=False) - cmd = [exe, "genkey"] - return subprocess.check_output(cmd, stderr=subprocess.STDOUT) + exe = _resolve_wg_exe() + return subprocess.check_output([exe, "genkey"], stderr=subprocess.STDOUT) def GenerateWireguardPublicKey(privateKey: str) -> tuple[bool, str] | tuple[bool, None]: try: diff --git a/src/modules/WireguardConfiguration.py b/src/modules/WireguardConfiguration.py index c5d261e9..3e3b2dc8 100644 --- a/src/modules/WireguardConfiguration.py +++ b/src/modules/WireguardConfiguration.py @@ -16,12 +16,26 @@ from .Peer import Peer from .PeerJobs import PeerJobs from .PeerShareLinks import PeerShareLinks from .Utilities import StringToBoolean, GenerateWireguardPublicKey, RegexMatch, ValidateDNSAddress, \ - ValidateEndpointAllowedIPs, WgShow, WgQuick, WgSetPeerAllowedIps, WgPeerRemove + ValidateEndpointAllowedIPs from .WireguardConfigurationInfo import WireguardConfigurationInfo, PeerGroupsClass from .DashboardWebHooks import DashboardWebHooks class WireguardConfiguration: + _WG_BINARIES = { + "wg": { + "exe": ("/usr/sbin/wg", "/usr/bin/wg"), + "quick": ("/usr/sbin/wg-quick", "/usr/bin/wg-quick"), + }, + "awg": { + "exe": ("/usr/sbin/awg", "/usr/bin/awg"), + "quick": ("/usr/sbin/awg-quick", "/usr/bin/awg-quick"), + }, + } + _SUDO_BINARIES = ("/usr/sbin/sudo", "/usr/bin/sudo") + _IFACE_RE = re.compile(r"^[A-Za-z0-9_.-]{1,15}$") + _PEER_RE = re.compile(r"^[A-Za-z0-9+/=]{32,64}$") + class InvalidConfigurationFileException(Exception): def __init__(self, m): self.message = m @@ -142,6 +156,122 @@ class WireguardConfiguration: self.addAutostart() + def _resolve_executable(self, quick: bool) -> str: + if self.Protocol not in self._WG_BINARIES: + raise ValueError(f"Unsupported protocol: {self.Protocol}") + key = "quick" if quick else "exe" + candidates = self._WG_BINARIES[self.Protocol][key] + for path in candidates: + if os.path.exists(path): + return path + fallback = shutil.which(f"{self.Protocol}-quick" if quick else self.Protocol) + if fallback: + fallback = os.path.realpath(fallback) + if fallback in candidates: + return fallback + raise FileNotFoundError(f"{self.Protocol} binary not found in allowed paths") + + def _resolve_sudo(self) -> str: + for path in self._SUDO_BINARIES: + if os.path.exists(path): + return path + fallback = shutil.which("sudo") + if fallback: + fallback = os.path.realpath(fallback) + if fallback in self._SUDO_BINARIES: + return fallback + raise FileNotFoundError("sudo not found in allowed paths") + + def _apply_sudo(self, cmd: list[str]) -> list[str]: + if os.geteuid() != 0: + sudo_path = self._resolve_sudo() + return [sudo_path, "--non-interactive"] + cmd + return cmd + + def _validate_interface(self) -> str: + if not self.Name or not self._IFACE_RE.fullmatch(self.Name): + raise ValueError(f"Invalid interface name: {self.Name}") + return self.Name + + def _validate_peer_id(self, peer_id: str) -> str: + if not peer_id or not self._PEER_RE.fullmatch(peer_id): + raise ValueError("Invalid peer public key") + return peer_id + + def _normalize_allowed_ips(self, allowed_ips: str) -> str: + cleaned = str(allowed_ips or "").replace(" ", "") + ok, err = ValidateEndpointAllowedIPs(cleaned) + if not ok: + raise ValueError(err or "Invalid AllowedIPs") + return cleaned + + def _wg_show_transfer(self) -> bytes: + exe = self._resolve_executable(quick=False) + iface = self._validate_interface() + cmd = self._apply_sudo([exe, "show", iface, "transfer"]) + return subprocess.check_output(cmd, stderr=subprocess.STDOUT) + + def _wg_show_endpoints(self) -> bytes: + exe = self._resolve_executable(quick=False) + iface = self._validate_interface() + cmd = self._apply_sudo([exe, "show", iface, "endpoints"]) + return subprocess.check_output(cmd, stderr=subprocess.STDOUT) + + def _wg_show_latest_handshakes(self) -> bytes: + exe = self._resolve_executable(quick=False) + iface = self._validate_interface() + cmd = self._apply_sudo([exe, "show", iface, "latest-handshakes"]) + return subprocess.check_output(cmd, stderr=subprocess.STDOUT) + + def _wg_quick_up(self) -> bytes: + exe = self._resolve_executable(quick=True) + iface = self._validate_interface() + cmd = self._apply_sudo([exe, "up", iface]) + return subprocess.check_output(cmd, stderr=subprocess.STDOUT) + + def _wg_quick_down(self) -> bytes: + exe = self._resolve_executable(quick=True) + iface = self._validate_interface() + cmd = self._apply_sudo([exe, "down", iface]) + return subprocess.check_output(cmd, stderr=subprocess.STDOUT) + + def _wg_quick_save(self) -> bytes: + exe = self._resolve_executable(quick=True) + iface = self._validate_interface() + cmd = self._apply_sudo([exe, "save", iface]) + return subprocess.check_output(cmd, stderr=subprocess.STDOUT) + + def _wg_set_peer_allowed_ips(self, peer_id: str, + preshared_key_path: str | None = None) -> bytes: + exe = self._resolve_executable(quick=False) + iface = self._validate_interface() + with self.engine.connect() as conn: + row = conn.execute( + self.peersTable.select().where(self.peersTable.c.id == peer_id) + ).mappings().fetchone() + if not row: + raise ValueError("Peer not found") + peer = self._validate_peer_id(row["id"]) + allowed = self._normalize_allowed_ips(row.get("allowed_ip")) + cmd = [exe, "set", iface, "peer", peer, "allowed-ips", allowed] + if preshared_key_path: + cmd += ["preshared-key", preshared_key_path] + cmd = self._apply_sudo(cmd) + return subprocess.check_output(cmd, stderr=subprocess.STDOUT) + + def _wg_peer_remove(self, peer_id: str) -> bytes: + exe = self._resolve_executable(quick=False) + iface = self._validate_interface() + with self.engine.connect() as conn: + row = conn.execute( + self.peersTable.select().where(self.peersTable.c.id == peer_id) + ).mappings().fetchone() + if not row: + raise ValueError("Peer not found") + peer = self._validate_peer_id(row["id"]) + cmd = self._apply_sudo([exe, "set", iface, "peer", peer, "remove"]) + return subprocess.check_output(cmd, stderr=subprocess.STDOUT) + def __getProtocolPath(self) -> str: _, path = self.DashboardConfig.GetConfig("Server", "wg_conf_path") if self.Protocol == "wg" \ else self.DashboardConfig.GetConfig("Server", "awg_conf_path") @@ -560,16 +690,13 @@ class WireguardConfiguration: with open(uid, "w+") as f: f.write(p['preshared_key']) - WgSetPeerAllowedIps( - self.Protocol, - self.Name, + self._wg_set_peer_allowed_ips( p['id'], - p['allowed_ip'], uid if presharedKeyExist else None ) if presharedKeyExist: os.remove(uid) - WgQuick(self.Protocol, "save", self.Name) + self._wg_quick_save() self.getPeers() for p in peers: p = self.searchPeer(p['id']) @@ -619,11 +746,8 @@ class WireguardConfiguration: with open(uid, "w+") as f: f.write(restrictedPeer['preshared_key']) - WgSetPeerAllowedIps( - self.Protocol, - self.Name, + self._wg_set_peer_allowed_ips( restrictedPeer['id'], - restrictedPeer['allowed_ip'], uid if presharedKeyExist else None ) if presharedKeyExist: os.remove(uid) @@ -645,7 +769,7 @@ class WireguardConfiguration: found, pf = self.searchPeer(p) if found: try: - WgPeerRemove(self.Protocol, self.Name, pf.id) + self._wg_peer_remove(pf.id) conn.execute( self.peersRestrictedTable.insert().from_select( [c.name for c in self.peersTable.columns], @@ -696,7 +820,7 @@ class WireguardConfiguration: AllPeerShareLinks.updateLinkExpireDate(shareLink.ShareID, datetime.now()) if found: try: - WgPeerRemove(self.Protocol, self.Name, pf.id) + self._wg_peer_remove(pf.id) conn.execute( self.peersTable.delete().where( self.peersTable.columns.id == pf.id @@ -726,7 +850,7 @@ class WireguardConfiguration: def __wgSave(self) -> tuple[bool, str] | tuple[bool, None]: try: - WgQuick(self.Protocol, "save", self.Name) + self._wg_quick_save() return True, None except subprocess.CalledProcessError as e: return False, str(e) @@ -735,7 +859,7 @@ class WireguardConfiguration: if not self.getStatus(): self.toggleConfiguration() try: - latestHandshake = WgShow(self.Protocol, self.Name, "latest-handshakes") + latestHandshake = self._wg_show_latest_handshakes() except subprocess.CalledProcessError: return "stopped" latestHandshake = latestHandshake.decode("UTF-8").split() @@ -774,7 +898,7 @@ class WireguardConfiguration: if not self.getStatus(): self.toggleConfiguration() # try: - data_usage = WgShow(self.Protocol, self.Name, "transfer") + data_usage = self._wg_show_transfer() data_usage = data_usage.decode("UTF-8").split("\n") data_usage = [p.split("\t") for p in data_usage] @@ -830,7 +954,7 @@ class WireguardConfiguration: if not self.getStatus(): self.toggleConfiguration() try: - data_usage = WgShow(self.Protocol, self.Name, "endpoints") + data_usage = self._wg_show_endpoints() except subprocess.CalledProcessError: return "stopped" data_usage = data_usage.decode("UTF-8").split() @@ -850,13 +974,13 @@ class WireguardConfiguration: self.getStatus() if self.Status: try: - check = WgQuick(self.Protocol, "down", self.Name) + check = self._wg_quick_down() self.removeAutostart() except subprocess.CalledProcessError as exc: return False, str(exc.output.strip().decode("utf-8")) else: try: - check = WgQuick(self.Protocol, "up", self.Name) + check = self._wg_quick_up() self.addAutostart() except subprocess.CalledProcessError as exc: return False, str(exc.output.strip().decode("utf-8"))