diff --git a/.github/docker/msvc-wine.dockerfile b/.github/docker/msvc-wine.dockerfile new file mode 100644 index 00000000..56d8a42d --- /dev/null +++ b/.github/docker/msvc-wine.dockerfile @@ -0,0 +1,83 @@ +# Cross-compiles Windows kernel drivers using msvc-wine. +# Toolchain: MSVC 17.13 + SDK/WDK 10.0.22621 +# Architectures: x86, x64, ARM, ARM64 +FROM alpine:3.21 AS msvc-wine +RUN apk add --no-cache git +RUN git clone https://github.com/mcha-forks/msvc-wine.git /msvc-wine \ + && cd /msvc-wine && git checkout 2d5f5a3 + +FROM registry.fedoraproject.org/fedora-minimal:44 AS wine + +COPY <<-'EOF' /etc/yum.repos.d/_copr:copr.fedorainfracloud.org:mochaa:wine.repo + [copr:copr.fedorainfracloud.org:mochaa:wine] + name=Copr repo for wine owned by mochaa + baseurl=https://download.copr.fedorainfracloud.org/results/mochaa/wine/fedora-$releasever-$basearch/ + type=rpm-md + skip_if_unavailable=True + gpgcheck=1 + gpgkey=https://download.copr.fedorainfracloud.org/results/mochaa/wine/pubkey.gpg + repo_gpgcheck=0 + enabled=1 + enabled_metadata=1 + + [copr:copr.fedorainfracloud.org:mochaa:wine:ml] + name=Copr repo for wine owned by mochaa (i386) + baseurl=https://download.copr.fedorainfracloud.org/results/mochaa/wine/fedora-$releasever-i386/ + type=rpm-md + skip_if_unavailable=True + gpgcheck=1 + gpgkey=https://download.copr.fedorainfracloud.org/results/mochaa/wine/pubkey.gpg + repo_gpgcheck=0 + cost=1100 + enabled=1 + enabled_metadata=1 +EOF + +RUN <<-EOF + set -xeu + microdnf install -y wine-core wine-core.i686 wine-mono + microdnf clean all +EOF + +RUN <<-EOF + set -xeu + wine wineboot -u + wine reg.exe add HKCU\\Software\\Wine\\Drivers /v Graphics /t REG_SZ /d null + wineserver -w +EOF + +WORKDIR /builddir + +FROM wine AS fetch-wdk +COPY --from=msvc-wine /msvc-wine/wdk-download.sh ./ +RUN WINEDEBUG=1 bash -x ./wdk-download.sh --cache wdk https://go.microsoft.com/fwlink/?linkid=2330411 + +FROM python:3.14-slim AS fetch-msvc +WORKDIR /builddir +COPY --from=msvc-wine /msvc-wine/vsdownload.py ./ +RUN PYTHONUNBUFFERED=1 ./vsdownload.py --accept-license --only-download --cache cache \ + --major=17 --msvc-version=17.13 --sdk-version=10.0.22621 --with-wdk-installer wdk/Installers \ + Microsoft.Component.MSBuild + +FROM wine AS builder +RUN <<-EOF + microdnf install -y git msitools perl + microdnf clean all +EOF +COPY --from=fetch-msvc /builddir/cache/ ./cache/ +COPY --from=fetch-wdk /builddir/wdk/Installers/ ./wdk/Installers/ +COPY --from=msvc-wine /msvc-wine/vsdownload.py ./ +COPY --from=msvc-wine /msvc-wine/patches/ ./patches/ +RUN PYTHONUNBUFFERED=1 python3 ./vsdownload.py --accept-license --cache cache --dest /opt/msvc \ + --major=17 --msvc-version=17.13 --sdk-version=10.0.22621 --with-wdk-installer wdk/Installers \ + Microsoft.Component.MSBuild +COPY --from=msvc-wine /msvc-wine/lowercase /msvc-wine/fixinclude /msvc-wine/install.sh /msvc-wine/msvctricks.cpp ./ +COPY --from=msvc-wine /msvc-wine/wrappers/ ./wrappers/ +RUN bash -x ./install.sh /opt/msvc +# WDK 22621 MSBuild targets reject Win32/ARM for km drivers (removed in Windows 11). +# The km libraries for these arches are present; only the validation blocks them. +RUN find -L /opt/msvc -iname 'windowsdriver.common.targets' \ + -exec sed -i '/not a valid architecture for Kernel mode/s/ int: + return int.from_bytes(data[offset:offset + 2], "little") + + +def read_u32(data: bytes, offset: int) -> int: + return int.from_bytes(data[offset:offset + 4], "little") + + +def canonicalize_signed_pe(path: Path) -> bytes: + data = bytearray(path.read_bytes()) + if len(data) < 0x40 or data[:2] != b"MZ": + raise PEFormatError(f"{path}: missing DOS header") + + pe_offset = read_u32(data, 0x3C) + if pe_offset + 24 > len(data) or data[pe_offset:pe_offset + 4] != b"PE\x00\x00": + raise PEFormatError(f"{path}: missing PE header") + + optional_offset = pe_offset + 24 + magic = read_u16(data, optional_offset) + if magic == 0x10B: + data_directory_offset = optional_offset + 96 + elif magic == 0x20B: + data_directory_offset = optional_offset + 112 + else: + raise PEFormatError(f"{path}: unsupported optional header magic 0x{magic:04x}") + + checksum_offset = optional_offset + 64 + if checksum_offset + 4 > len(data): + raise PEFormatError(f"{path}: truncated optional header") + data[checksum_offset:checksum_offset + 4] = b"\x00" * 4 + + security_directory_offset = data_directory_offset + 8 * 4 + if security_directory_offset + 8 > len(data): + raise PEFormatError(f"{path}: truncated data directories") + + certificate_offset = read_u32(data, security_directory_offset) + certificate_size = read_u32(data, security_directory_offset + 4) + data[security_directory_offset:security_directory_offset + 8] = b"\x00" * 8 + + if certificate_offset == 0 or certificate_size == 0: + return bytes(data) + + certificate_end = certificate_offset + certificate_size + if certificate_end > len(data): + raise PEFormatError(f"{path}: certificate table exceeds file size") + + return bytes(data[:certificate_offset] + data[certificate_end:]) + + +def canonical_sha256(path: Path) -> str: + return hashlib.sha256(canonicalize_signed_pe(path)).hexdigest() + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Compare two PE files after stripping Authenticode-only differences.", + ) + parser.add_argument("left", type=Path) + parser.add_argument("right", type=Path) + args = parser.parse_args() + + try: + left_hash = canonical_sha256(args.left) + right_hash = canonical_sha256(args.right) + except (OSError, PEFormatError) as exc: + print(exc, file=sys.stderr) + return 2 + + print(f"{args.left}: {left_hash}") + print(f"{args.right}: {right_hash}") + return 0 if left_hash == right_hash else 1 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/.github/workflows/winredirect-driver-sync.yml b/.github/workflows/winredirect-driver-sync.yml new file mode 100644 index 00000000..1eee1dd6 --- /dev/null +++ b/.github/workflows/winredirect-driver-sync.yml @@ -0,0 +1,160 @@ +name: sync winredirect driver + +on: + pull_request: + branches: + - main + - dev + paths: + - '.github/docker/msvc-wine.dockerfile' + - '.github/scripts/compare_signed_pe.py' + - '.github/workflows/winredirect-driver-sync.yml' + - 'internal/winredirect/driver/**' + +permissions: + contents: write + +concurrency: + group: winredirect-driver-${{ github.event.pull_request.head.repo.full_name }}-${{ github.event.pull_request.head.ref }} + cancel-in-progress: true + +jobs: + sync: + name: Refresh bundled drivers + if: github.event.pull_request.head.repo.full_name == github.repository + runs-on: ubuntu-latest + steps: + - name: Checkout PR head + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5 + with: + fetch-depth: 0 + ref: ${{ github.event.pull_request.head.ref }} + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Build msvc-wine image + uses: docker/build-push-action@v6 + with: + file: .github/docker/msvc-wine.dockerfile + load: true + tags: msvc-wine:local + cache-from: type=gha,scope=msvc-wine + cache-to: type=gha,mode=max,scope=msvc-wine + + - name: Print toolchain versions + run: | + docker run --rm msvc-wine:local /opt/msvc/bin/x64/cl 2>&1 | head -2 || true + echo "SDK/WDK include versions:" + docker run --rm msvc-wine:local ls /opt/msvc/kits/10/include/ + echo "WDK km lib architectures:" + docker run --rm msvc-wine:local sh -c 'ls /opt/msvc/kits/10/lib/*/km/ 2>/dev/null' || true + echo "WindowsDriver.common.targets (arch validation):" + docker run --rm msvc-wine:local sh -c 'find /opt/msvc -name "WindowsDriver.common.targets" -exec grep -n -i "valid architecture\|_SUPPORTED" {} +' 2>/dev/null || true + + - name: Build, verify reproducibility, and refresh tracked drivers + id: sync + run: | + set -euo pipefail + + arches=(x64 arm64 x86) + platforms=(x64 ARM64 Win32) + tracked=( + internal/winredirect/amd64/winredirect.sys + internal/winredirect/arm64/winredirect.sys + internal/winredirect/x86/winredirect.sys + ) + + changed=false + artifact_dir="${RUNNER_TEMP}/winredirect-driver-sync" + failure_dir="${artifact_dir}/failure" + mkdir -p "$artifact_dir" + + build_driver() { + docker run --rm \ + -v "$PWD:/src" -w /src \ + -e CL=/Brepro -e LINK=/Brepro \ + msvc-wine:local \ + /opt/msvc/bin/"$1"/msbuild \ + internal/winredirect/driver/winredirect.vcxproj \ + /t:Rebuild /p:Configuration=Release /p:SpectreMitigation=false /p:TrackFileAccess=false /v:minimal + } + + for i in "${!arches[@]}"; do + arch="${arches[$i]}" + platform="${platforms[$i]}" + tracked_file="${tracked[$i]}" + build_output="internal/winredirect/driver/build/${platform}/Release/winredirect.sys" + run1="${artifact_dir}/${arch}-run1.sys" + run2="${artifact_dir}/${arch}-run2.sys" + + echo "::group::Build ${arch} run 1" + build_driver "$arch" + cp "$build_output" "$run1" + echo "::endgroup::" + + echo "::group::Build ${arch} run 2" + build_driver "$arch" + cp "$build_output" "$run2" + echo "::endgroup::" + + echo "::group::Verify ${arch} reproducibility" + rc=0 + python3 .github/scripts/compare_signed_pe.py "$run1" "$run2" || rc=$? + case $rc in + 0) + echo "${arch} reproduced after stripping Authenticode data." + ;; + 1) + mkdir -p "$failure_dir" + cp "$run1" "${failure_dir}/${arch}-run1.sys" + cp "$run2" "${failure_dir}/${arch}-run2.sys" + echo "::error::${arch} build is not reproducible beyond signing metadata." + exit 1 + ;; + *) + echo "::error::Reproducibility comparison failed for ${arch} with exit code ${rc}" + exit "$rc" + ;; + esac + echo "::endgroup::" + + echo "::group::Compare ${arch} with tracked driver" + rc=0 + python3 .github/scripts/compare_signed_pe.py "$run2" "$tracked_file" || rc=$? + case $rc in + 0) + echo "${arch} matches the tracked driver after stripping Authenticode data." + ;; + 1|2) + echo "${arch} differs beyond signing metadata (rc=${rc}); replacing tracked driver." + cp "$run2" "$tracked_file" + git add -- "$tracked_file" + changed=true + ;; + *) + echo "::error::Comparison failed for ${arch} with exit code ${rc}" + exit "$rc" + ;; + esac + echo "::endgroup::" + done + + echo "changed=${changed}" >> "$GITHUB_OUTPUT" + + - name: Upload failed reproducibility artifacts + if: failure() + uses: actions/upload-artifact@v4 + with: + name: winredirect-repro-failure-${{ github.run_id }}-${{ github.run_attempt }} + path: ${{ runner.temp }}/winredirect-driver-sync/failure/*.sys + if-no-files-found: warn + + - name: Commit and push refreshed drivers + if: steps.sync.outputs.changed == 'true' + run: | + git config user.name 'github-actions[bot]' + git config user.email '41898282+github-actions[bot]@users.noreply.github.com' + git diff --cached --quiet && exit 0 + git commit -m 'winredirect: update bundled drivers' + git push origin "HEAD:${{ github.event.pull_request.head.ref }}" diff --git a/.gitignore b/.gitignore index 3f3c51ce..03b9ffb5 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,5 @@ .DS_Store !/README.md /*.md +/.claude/ +/CLAUDE.md \ No newline at end of file diff --git a/internal/winipcfg/zwinipcfg_windows.go b/internal/winipcfg/zwinipcfg_windows.go index 3a0d8680..052134c5 100644 --- a/internal/winipcfg/zwinipcfg_windows.go +++ b/internal/winipcfg/zwinipcfg_windows.go @@ -74,7 +74,7 @@ var ( ) func cancelMibChangeNotify2(notificationHandle windows.Handle) (ret error) { - r0, _, _ := syscall.Syscall(procCancelMibChangeNotify2.Addr(), 1, uintptr(notificationHandle), 0, 0) + r0, _, _ := syscall.SyscallN(procCancelMibChangeNotify2.Addr(), uintptr(notificationHandle)) if r0 != 0 { ret = syscall.Errno(r0) } @@ -82,7 +82,7 @@ func cancelMibChangeNotify2(notificationHandle windows.Handle) (ret error) { } func convertInterfaceGUIDToLUID(interfaceGUID *windows.GUID, interfaceLUID *LUID) (ret error) { - r0, _, _ := syscall.Syscall(procConvertInterfaceGuidToLuid.Addr(), 2, uintptr(unsafe.Pointer(interfaceGUID)), uintptr(unsafe.Pointer(interfaceLUID)), 0) + r0, _, _ := syscall.SyscallN(procConvertInterfaceGuidToLuid.Addr(), uintptr(unsafe.Pointer(interfaceGUID)), uintptr(unsafe.Pointer(interfaceLUID))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -90,7 +90,7 @@ func convertInterfaceGUIDToLUID(interfaceGUID *windows.GUID, interfaceLUID *LUID } func convertInterfaceIndexToLUID(interfaceIndex uint32, interfaceLUID *LUID) (ret error) { - r0, _, _ := syscall.Syscall(procConvertInterfaceIndexToLuid.Addr(), 2, uintptr(interfaceIndex), uintptr(unsafe.Pointer(interfaceLUID)), 0) + r0, _, _ := syscall.SyscallN(procConvertInterfaceIndexToLuid.Addr(), uintptr(interfaceIndex), uintptr(unsafe.Pointer(interfaceLUID))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -98,7 +98,7 @@ func convertInterfaceIndexToLUID(interfaceIndex uint32, interfaceLUID *LUID) (re } func convertInterfaceLUIDToGUID(interfaceLUID *LUID, interfaceGUID *windows.GUID) (ret error) { - r0, _, _ := syscall.Syscall(procConvertInterfaceLuidToGuid.Addr(), 2, uintptr(unsafe.Pointer(interfaceLUID)), uintptr(unsafe.Pointer(interfaceGUID)), 0) + r0, _, _ := syscall.SyscallN(procConvertInterfaceLuidToGuid.Addr(), uintptr(unsafe.Pointer(interfaceLUID)), uintptr(unsafe.Pointer(interfaceGUID))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -106,7 +106,7 @@ func convertInterfaceLUIDToGUID(interfaceLUID *LUID, interfaceGUID *windows.GUID } func createAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) { - r0, _, _ := syscall.Syscall(procCreateAnycastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) + r0, _, _ := syscall.SyscallN(procCreateAnycastIpAddressEntry.Addr(), uintptr(unsafe.Pointer(row))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -114,7 +114,7 @@ func createAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) { } func createIPForwardEntry2(route *MibIPforwardRow2) (ret error) { - r0, _, _ := syscall.Syscall(procCreateIpForwardEntry2.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0) + r0, _, _ := syscall.SyscallN(procCreateIpForwardEntry2.Addr(), uintptr(unsafe.Pointer(route))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -122,7 +122,7 @@ func createIPForwardEntry2(route *MibIPforwardRow2) (ret error) { } func createUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) { - r0, _, _ := syscall.Syscall(procCreateUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) + r0, _, _ := syscall.SyscallN(procCreateUnicastIpAddressEntry.Addr(), uintptr(unsafe.Pointer(row))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -130,7 +130,7 @@ func createUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) { } func deleteAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) { - r0, _, _ := syscall.Syscall(procDeleteAnycastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) + r0, _, _ := syscall.SyscallN(procDeleteAnycastIpAddressEntry.Addr(), uintptr(unsafe.Pointer(row))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -138,7 +138,7 @@ func deleteAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) { } func deleteIPForwardEntry2(route *MibIPforwardRow2) (ret error) { - r0, _, _ := syscall.Syscall(procDeleteIpForwardEntry2.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0) + r0, _, _ := syscall.SyscallN(procDeleteIpForwardEntry2.Addr(), uintptr(unsafe.Pointer(route))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -146,7 +146,7 @@ func deleteIPForwardEntry2(route *MibIPforwardRow2) (ret error) { } func deleteUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) { - r0, _, _ := syscall.Syscall(procDeleteUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) + r0, _, _ := syscall.SyscallN(procDeleteUnicastIpAddressEntry.Addr(), uintptr(unsafe.Pointer(row))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -154,12 +154,12 @@ func deleteUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) { } func freeMibTable(memory unsafe.Pointer) { - syscall.Syscall(procFreeMibTable.Addr(), 1, uintptr(memory), 0, 0) + syscall.SyscallN(procFreeMibTable.Addr(), uintptr(memory)) return } func getAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) { - r0, _, _ := syscall.Syscall(procGetAnycastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) + r0, _, _ := syscall.SyscallN(procGetAnycastIpAddressEntry.Addr(), uintptr(unsafe.Pointer(row))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -167,7 +167,7 @@ func getAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) { } func getAnycastIPAddressTable(family AddressFamily, table **mibAnycastIPAddressTable) (ret error) { - r0, _, _ := syscall.Syscall(procGetAnycastIpAddressTable.Addr(), 2, uintptr(family), uintptr(unsafe.Pointer(table)), 0) + r0, _, _ := syscall.SyscallN(procGetAnycastIpAddressTable.Addr(), uintptr(family), uintptr(unsafe.Pointer(table))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -175,7 +175,7 @@ func getAnycastIPAddressTable(family AddressFamily, table **mibAnycastIPAddressT } func getIfEntry2(row *MibIfRow2) (ret error) { - r0, _, _ := syscall.Syscall(procGetIfEntry2.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) + r0, _, _ := syscall.SyscallN(procGetIfEntry2.Addr(), uintptr(unsafe.Pointer(row))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -183,7 +183,7 @@ func getIfEntry2(row *MibIfRow2) (ret error) { } func getIfTable2Ex(level MibIfEntryLevel, table **mibIfTable2) (ret error) { - r0, _, _ := syscall.Syscall(procGetIfTable2Ex.Addr(), 2, uintptr(level), uintptr(unsafe.Pointer(table)), 0) + r0, _, _ := syscall.SyscallN(procGetIfTable2Ex.Addr(), uintptr(level), uintptr(unsafe.Pointer(table))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -191,7 +191,7 @@ func getIfTable2Ex(level MibIfEntryLevel, table **mibIfTable2) (ret error) { } func getIPForwardEntry2(route *MibIPforwardRow2) (ret error) { - r0, _, _ := syscall.Syscall(procGetIpForwardEntry2.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0) + r0, _, _ := syscall.SyscallN(procGetIpForwardEntry2.Addr(), uintptr(unsafe.Pointer(route))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -199,7 +199,7 @@ func getIPForwardEntry2(route *MibIPforwardRow2) (ret error) { } func getIPForwardTable2(family AddressFamily, table **mibIPforwardTable2) (ret error) { - r0, _, _ := syscall.Syscall(procGetIpForwardTable2.Addr(), 2, uintptr(family), uintptr(unsafe.Pointer(table)), 0) + r0, _, _ := syscall.SyscallN(procGetIpForwardTable2.Addr(), uintptr(family), uintptr(unsafe.Pointer(table))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -207,7 +207,7 @@ func getIPForwardTable2(family AddressFamily, table **mibIPforwardTable2) (ret e } func getIPInterfaceEntry(row *MibIPInterfaceRow) (ret error) { - r0, _, _ := syscall.Syscall(procGetIpInterfaceEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) + r0, _, _ := syscall.SyscallN(procGetIpInterfaceEntry.Addr(), uintptr(unsafe.Pointer(row))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -215,7 +215,7 @@ func getIPInterfaceEntry(row *MibIPInterfaceRow) (ret error) { } func getIPInterfaceTable(family AddressFamily, table **mibIPInterfaceTable) (ret error) { - r0, _, _ := syscall.Syscall(procGetIpInterfaceTable.Addr(), 2, uintptr(family), uintptr(unsafe.Pointer(table)), 0) + r0, _, _ := syscall.SyscallN(procGetIpInterfaceTable.Addr(), uintptr(family), uintptr(unsafe.Pointer(table))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -223,7 +223,7 @@ func getIPInterfaceTable(family AddressFamily, table **mibIPInterfaceTable) (ret } func getUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) { - r0, _, _ := syscall.Syscall(procGetUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) + r0, _, _ := syscall.SyscallN(procGetUnicastIpAddressEntry.Addr(), uintptr(unsafe.Pointer(row))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -231,7 +231,7 @@ func getUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) { } func getUnicastIPAddressTable(family AddressFamily, table **mibUnicastIPAddressTable) (ret error) { - r0, _, _ := syscall.Syscall(procGetUnicastIpAddressTable.Addr(), 2, uintptr(family), uintptr(unsafe.Pointer(table)), 0) + r0, _, _ := syscall.SyscallN(procGetUnicastIpAddressTable.Addr(), uintptr(family), uintptr(unsafe.Pointer(table))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -239,17 +239,17 @@ func getUnicastIPAddressTable(family AddressFamily, table **mibUnicastIPAddressT } func initializeIPForwardEntry(route *MibIPforwardRow2) { - syscall.Syscall(procInitializeIpForwardEntry.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0) + syscall.SyscallN(procInitializeIpForwardEntry.Addr(), uintptr(unsafe.Pointer(route))) return } func initializeIPInterfaceEntry(row *MibIPInterfaceRow) { - syscall.Syscall(procInitializeIpInterfaceEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) + syscall.SyscallN(procInitializeIpInterfaceEntry.Addr(), uintptr(unsafe.Pointer(row))) return } func initializeUnicastIPAddressEntry(row *MibUnicastIPAddressRow) { - syscall.Syscall(procInitializeUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) + syscall.SyscallN(procInitializeUnicastIpAddressEntry.Addr(), uintptr(unsafe.Pointer(row))) return } @@ -258,7 +258,7 @@ func notifyIPInterfaceChange(family AddressFamily, callback uintptr, callerConte if initialNotification { _p0 = 1 } - r0, _, _ := syscall.Syscall6(procNotifyIpInterfaceChange.Addr(), 5, uintptr(family), uintptr(callback), uintptr(callerContext), uintptr(_p0), uintptr(unsafe.Pointer(notificationHandle)), 0) + r0, _, _ := syscall.SyscallN(procNotifyIpInterfaceChange.Addr(), uintptr(family), uintptr(callback), uintptr(callerContext), uintptr(_p0), uintptr(unsafe.Pointer(notificationHandle))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -270,7 +270,7 @@ func notifyRouteChange2(family AddressFamily, callback uintptr, callerContext ui if initialNotification { _p0 = 1 } - r0, _, _ := syscall.Syscall6(procNotifyRouteChange2.Addr(), 5, uintptr(family), uintptr(callback), uintptr(callerContext), uintptr(_p0), uintptr(unsafe.Pointer(notificationHandle)), 0) + r0, _, _ := syscall.SyscallN(procNotifyRouteChange2.Addr(), uintptr(family), uintptr(callback), uintptr(callerContext), uintptr(_p0), uintptr(unsafe.Pointer(notificationHandle))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -282,19 +282,19 @@ func notifyUnicastIPAddressChange(family AddressFamily, callback uintptr, caller if initialNotification { _p0 = 1 } - r0, _, _ := syscall.Syscall6(procNotifyUnicastIpAddressChange.Addr(), 5, uintptr(family), uintptr(callback), uintptr(callerContext), uintptr(_p0), uintptr(unsafe.Pointer(notificationHandle)), 0) + r0, _, _ := syscall.SyscallN(procNotifyUnicastIpAddressChange.Addr(), uintptr(family), uintptr(callback), uintptr(callerContext), uintptr(_p0), uintptr(unsafe.Pointer(notificationHandle))) if r0 != 0 { ret = syscall.Errno(r0) } return } -func setInterfaceDnsSettingsByDwords(guid1 uintptr, guid2 uintptr, guid3 uintptr, guid4 uintptr, settings *DnsInterfaceSettings) (ret error) { +func setInterfaceDnsSettingsByPtr(guid *windows.GUID, settings *DnsInterfaceSettings) (ret error) { ret = procSetInterfaceDnsSettings.Find() if ret != nil { return } - r0, _, _ := syscall.Syscall6(procSetInterfaceDnsSettings.Addr(), 5, uintptr(guid1), uintptr(guid2), uintptr(guid3), uintptr(guid4), uintptr(unsafe.Pointer(settings)), 0) + r0, _, _ := syscall.SyscallN(procSetInterfaceDnsSettings.Addr(), uintptr(unsafe.Pointer(guid)), uintptr(unsafe.Pointer(settings))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -306,19 +306,19 @@ func setInterfaceDnsSettingsByQwords(guid1 uintptr, guid2 uintptr, settings *Dns if ret != nil { return } - r0, _, _ := syscall.Syscall(procSetInterfaceDnsSettings.Addr(), 3, uintptr(guid1), uintptr(guid2), uintptr(unsafe.Pointer(settings))) + r0, _, _ := syscall.SyscallN(procSetInterfaceDnsSettings.Addr(), uintptr(guid1), uintptr(guid2), uintptr(unsafe.Pointer(settings))) if r0 != 0 { ret = syscall.Errno(r0) } return } -func setInterfaceDnsSettingsByPtr(guid *windows.GUID, settings *DnsInterfaceSettings) (ret error) { +func setInterfaceDnsSettingsByDwords(guid1 uintptr, guid2 uintptr, guid3 uintptr, guid4 uintptr, settings *DnsInterfaceSettings) (ret error) { ret = procSetInterfaceDnsSettings.Find() if ret != nil { return } - r0, _, _ := syscall.Syscall(procSetInterfaceDnsSettings.Addr(), 2, uintptr(unsafe.Pointer(guid)), uintptr(unsafe.Pointer(settings)), 0) + r0, _, _ := syscall.SyscallN(procSetInterfaceDnsSettings.Addr(), uintptr(guid1), uintptr(guid2), uintptr(guid3), uintptr(guid4), uintptr(unsafe.Pointer(settings))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -326,7 +326,7 @@ func setInterfaceDnsSettingsByPtr(guid *windows.GUID, settings *DnsInterfaceSett } func setIPForwardEntry2(route *MibIPforwardRow2) (ret error) { - r0, _, _ := syscall.Syscall(procSetIpForwardEntry2.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0) + r0, _, _ := syscall.SyscallN(procSetIpForwardEntry2.Addr(), uintptr(unsafe.Pointer(route))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -334,7 +334,7 @@ func setIPForwardEntry2(route *MibIPforwardRow2) (ret error) { } func setIPInterfaceEntry(row *MibIPInterfaceRow) (ret error) { - r0, _, _ := syscall.Syscall(procSetIpInterfaceEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) + r0, _, _ := syscall.SyscallN(procSetIpInterfaceEntry.Addr(), uintptr(unsafe.Pointer(row))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -342,7 +342,7 @@ func setIPInterfaceEntry(row *MibIPInterfaceRow) (ret error) { } func setUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) { - r0, _, _ := syscall.Syscall(procSetUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) + r0, _, _ := syscall.SyscallN(procSetUnicastIpAddressEntry.Addr(), uintptr(unsafe.Pointer(row))) if r0 != 0 { ret = syscall.Errno(r0) } diff --git a/internal/winredirect/amd64/winredirect.sys b/internal/winredirect/amd64/winredirect.sys new file mode 100644 index 00000000..78395808 Binary files /dev/null and b/internal/winredirect/amd64/winredirect.sys differ diff --git a/internal/winredirect/arm/winredirect.sys b/internal/winredirect/arm/winredirect.sys new file mode 100644 index 00000000..311c8dd0 --- /dev/null +++ b/internal/winredirect/arm/winredirect.sys @@ -0,0 +1 @@ +PLACEHOLDER \ No newline at end of file diff --git a/internal/winredirect/arm64/winredirect.sys b/internal/winredirect/arm64/winredirect.sys new file mode 100644 index 00000000..f930e22b Binary files /dev/null and b/internal/winredirect/arm64/winredirect.sys differ diff --git a/internal/winredirect/driver/winredirect.c b/internal/winredirect/driver/winredirect.c new file mode 100644 index 00000000..dd5f36b2 --- /dev/null +++ b/internal/winredirect/driver/winredirect.c @@ -0,0 +1,1108 @@ + +#include "winredirect.h" + +// {E513903C-D2F3-4D8C-9458-0483E7D7A01F} +DEFINE_GUID(WINREDIRECT_PROVIDER_KEY, + 0xe513903c, 0xd2f3, 0x4d8c, 0x94, 0x58, 0x04, 0x83, 0xe7, 0xd7, 0xa0, 0x1f); + +// {8987A44E-ECB2-4A47-9FB6-B749C804FA3B} +DEFINE_GUID(WINREDIRECT_SUBLAYER_KEY, + 0x8987a44e, 0xecb2, 0x4a47, 0x9f, 0xb6, 0xb7, 0x49, 0xc8, 0x04, 0xfa, 0x3b); + +// {7EA20C4E-1A93-427E-80DC-E18A60AAB73B} +DEFINE_GUID(WINREDIRECT_CALLOUT_V4_KEY, + 0x7ea20c4e, 0x1a93, 0x427e, 0x80, 0xdc, 0xe1, 0x8a, 0x60, 0xaa, 0xb7, 0x3b); + +// {AABE8538-0A09-4D47-8E61-1127CE5BB1AB} +DEFINE_GUID(WINREDIRECT_CALLOUT_V6_KEY, + 0xaabe8538, 0x0a09, 0x4d47, 0x8e, 0x61, 0x11, 0x27, 0xce, 0x5b, 0xb1, 0xab); + +static PDRIVER_CONTEXT g_Ctx = NULL; + +#define PENDING_QUEUED_TIMEOUT_SECONDS 5 +#define PENDING_DELIVERED_TIMEOUT_SECONDS 15 + +static void PermitClassify(_Inout_ FWPS_CLASSIFY_OUT0* classifyOut) +{ + classifyOut->actionType = FWP_ACTION_PERMIT; + classifyOut->rights &= ~FWPS_RIGHT_ACTION_WRITE; +} + +static void BlockClassify(_Inout_ FWPS_CLASSIFY_OUT0* classifyOut) +{ + classifyOut->actionType = FWP_ACTION_BLOCK; + classifyOut->rights &= ~FWPS_RIGHT_ACTION_WRITE; +} + +static NTSTATUS ReadFatalStatus(_In_ PDRIVER_CONTEXT Ctx) +{ + return (NTSTATUS)InterlockedCompareExchange(&Ctx->FatalStatus, STATUS_SUCCESS, STATUS_SUCCESS); +} + +static NTSTATUS NormalizeFatalStatus(_In_ NTSTATUS Status) +{ + if (NT_SUCCESS(Status)) { + return STATUS_DRIVER_INTERNAL_ERROR; + } + return Status; +} + +static NTSTATUS TriggerFatal(_In_ PDRIVER_CONTEXT Ctx, _In_ NTSTATUS Status, _In_ const char* Message) +{ + NTSTATUS normalized = NormalizeFatalStatus(Status); + NTSTATUS previous = (NTSTATUS)InterlockedCompareExchange(&Ctx->FatalStatus, normalized, STATUS_SUCCESS); + + if (previous == STATUS_SUCCESS) { + RtlStringCbCopyA(Ctx->FatalMessage, sizeof(Ctx->FatalMessage), Message); + WdfWorkItemEnqueue(Ctx->FatalWorkItem); + return normalized; + } + + return previous; +} + +static void TriggerFatalAndPermitClassify( + _In_ PDRIVER_CONTEXT Ctx, + _Inout_ FWPS_CLASSIFY_OUT0* classifyOut, + _In_ NTSTATUS Status, + _In_ const char* Message) +{ + if (Ctx) { + TriggerFatal(Ctx, Status, Message); + } + PermitClassify(classifyOut); +} + +static BOOLEAN IsLoopbackAddress(_In_ UINT8 AddressFamily, _In_reads_(16) const UINT8* Address) +{ + if (AddressFamily == AF_INET) { + return Address[0] == 127; + } + + if (AddressFamily == AF_INET6) { + for (UINT32 i = 0; i < 15; i++) { + if (Address[i] != 0) { + return FALSE; + } + } + return Address[15] == 1; + } + + return FALSE; +} + +static BOOLEAN IsAnyAddress(_In_ UINT8 AddressFamily, _In_reads_(16) const UINT8* Address) +{ + if (AddressFamily == AF_INET) { + return Address[0] == 0 && Address[1] == 0 && Address[2] == 0 && Address[3] == 0; + } + + if (AddressFamily == AF_INET6) { + for (UINT32 i = 0; i < 16; i++) { + if (Address[i] != 0) { + return FALSE; + } + } + return TRUE; + } + + return FALSE; +} + +typedef enum _BEST_ROUTE_RESULT { + BestRouteTun = 1, + BestRouteOther = 2, +} BEST_ROUTE_RESULT; + +static CONFIG_SNAPSHOT ReadConfigSnapshot(_In_ PDRIVER_CONTEXT Ctx) +{ + CONFIG_SNAPSHOT snapshot; + KIRQL oldIrql; + + RtlZeroMemory(&snapshot, sizeof(snapshot)); + KeAcquireSpinLock(&Ctx->ConfigLock, &oldIrql); + snapshot.Config = Ctx->Config; + snapshot.TunLuid = Ctx->TunLuid; + snapshot.HasTunLuid = Ctx->HasTunLuid; + KeReleaseSpinLock(&Ctx->ConfigLock, oldIrql); + + return snapshot; +} + +static BOOLEAN TryBestRouteForEntry( + _In_ const CONFIG_SNAPSHOT* Snapshot, + _In_ const PENDING_ENTRY* Entry, + _Out_ BEST_ROUTE_RESULT* Result) +{ + SOCKADDR_INET sourceAddress; + SOCKADDR_INET destinationAddress; + SOCKADDR_INET bestSourceAddress; + SOCKADDR_INET* sourceAddressPtr = NULL; + MIB_IPFORWARD_ROW2 bestRoute; + NETIO_STATUS status; + + if (!Snapshot->HasTunLuid) { + return FALSE; + } + // GetBestRoute2 requires IRQL < DISPATCH_LEVEL. We do not currently + // characterize every runtime context where route lookup can be unavailable, + // so report a normal lookup failure and let the caller decide the fallback. + if (KeGetCurrentIrql() >= DISPATCH_LEVEL) { + return FALSE; + } + + RtlZeroMemory(&sourceAddress, sizeof(sourceAddress)); + RtlZeroMemory(&destinationAddress, sizeof(destinationAddress)); + RtlZeroMemory(&bestSourceAddress, sizeof(bestSourceAddress)); + RtlZeroMemory(&bestRoute, sizeof(bestRoute)); + + if (Entry->AddressFamily == AF_INET) { + destinationAddress.Ipv4.sin_family = AF_INET; + RtlCopyMemory(&destinationAddress.Ipv4.sin_addr, Entry->DstAddr, sizeof(destinationAddress.Ipv4.sin_addr)); + if (!IsAnyAddress(AF_INET, Entry->SrcAddr)) { + sourceAddress.Ipv4.sin_family = AF_INET; + RtlCopyMemory(&sourceAddress.Ipv4.sin_addr, Entry->SrcAddr, sizeof(sourceAddress.Ipv4.sin_addr)); + sourceAddressPtr = &sourceAddress; + } + } else if (Entry->AddressFamily == AF_INET6) { + destinationAddress.Ipv6.sin6_family = AF_INET6; + RtlCopyMemory(destinationAddress.Ipv6.sin6_addr.u.Byte, Entry->DstAddr, sizeof(destinationAddress.Ipv6.sin6_addr.u.Byte)); + if (!IsAnyAddress(AF_INET6, Entry->SrcAddr)) { + sourceAddress.Ipv6.sin6_family = AF_INET6; + RtlCopyMemory(sourceAddress.Ipv6.sin6_addr.u.Byte, Entry->SrcAddr, sizeof(sourceAddress.Ipv6.sin6_addr.u.Byte)); + sourceAddressPtr = &sourceAddress; + } + } else { + return FALSE; + } + + status = GetBestRoute2(NULL, 0, sourceAddressPtr, &destinationAddress, 0, &bestRoute, &bestSourceAddress); + if (status != 0) { + return FALSE; + } + if (bestRoute.InterfaceLuid.Value == Snapshot->TunLuid.Value) { + *Result = BestRouteTun; + return TRUE; + } + *Result = BestRouteOther; + return TRUE; +} + +static void CancelPendingIoctlRequests(_In_ PDRIVER_CONTEXT Ctx, _In_ NTSTATUS Status) +{ + WDFREQUEST request; + + while (NT_SUCCESS(WdfIoQueueRetrieveNextRequest(Ctx->PendingIoctlQueue, &request))) { + WdfRequestComplete(request, Status); + } +} + +static void ShutdownRedirect(_In_ PDRIVER_CONTEXT Ctx, _In_ UINT32 PendingVerdict, _In_ NTSTATUS RequestStatus) +{ + if (InterlockedCompareExchange(&Ctx->Running, FALSE, TRUE) == TRUE) { + WdfTimerStop(Ctx->TimeoutTimer, TRUE); + WdfWorkItemFlush(Ctx->TimeoutWorkItem); + WfpCleanup(Ctx); + WdfWorkItemFlush(Ctx->PendingDeliveryWorkItem); + } + + PendingFlushAll(Ctx, PendingVerdict); + CancelPendingIoctlRequests(Ctx, RequestStatus); +} + +static PPENDING_ENTRY PendingReserveNextQueued(_In_ PDRIVER_CONTEXT Ctx) +{ + PPENDING_ENTRY found = NULL; + KIRQL oldIrql; + + KeAcquireSpinLock(&Ctx->PendingLock, &oldIrql); + PLIST_ENTRY entry = Ctx->PendingList.Flink; + while (entry != &Ctx->PendingList) { + PPENDING_ENTRY pending = CONTAINING_RECORD(entry, PENDING_ENTRY, ListEntry); + if (pending->DeliveryState == PendingDeliveryQueued) { + pending->DeliveryState = PendingDeliveryCopying; + found = pending; + break; + } + entry = entry->Flink; + } + KeReleaseSpinLock(&Ctx->PendingLock, oldIrql); + + return found; +} + +static void PendingSetDeliveryState( + _In_ PDRIVER_CONTEXT Ctx, + _In_ PPENDING_ENTRY Entry, + _In_ LONG State, + _In_opt_ const LARGE_INTEGER* Timestamp) +{ + KIRQL oldIrql; + + KeAcquireSpinLock(&Ctx->PendingLock, &oldIrql); + Entry->DeliveryState = State; + if (Timestamp) { + Entry->Timestamp = *Timestamp; + } + KeReleaseSpinLock(&Ctx->PendingLock, oldIrql); +} + +static void TryCompletePendingRequests(_In_ PDRIVER_CONTEXT Ctx) +{ + if (ReadFatalStatus(Ctx) != STATUS_SUCCESS) { + return; + } + + for (;;) { + PPENDING_ENTRY pending = PendingReserveNextQueued(Ctx); + if (!pending) { + break; + } + + WDFREQUEST request; + NTSTATUS status = WdfIoQueueRetrieveNextRequest(Ctx->PendingIoctlQueue, &request); + if (!NT_SUCCESS(status)) { + PendingSetDeliveryState(Ctx, pending, PendingDeliveryQueued, NULL); + break; + } + + PVOID outBuf; + status = WdfRequestRetrieveOutputBuffer(request, sizeof(WINREDIRECT_PENDING_CONN), &outBuf, NULL); + if (!NT_SUCCESS(status)) { + PendingSetDeliveryState(Ctx, pending, PendingDeliveryQueued, NULL); + WdfRequestComplete(request, status); + continue; + } + + WINREDIRECT_PENDING_CONN* out = (WINREDIRECT_PENDING_CONN*)outBuf; + RtlZeroMemory(out, sizeof(*out)); + out->ConnID = pending->ConnID; + out->AddressFamily = pending->AddressFamily; + RtlCopyMemory(out->SrcAddr, pending->SrcAddr, 16); + out->SrcPort = pending->SrcPort; + RtlCopyMemory(out->DstAddr, pending->DstAddr, 16); + out->DstPort = pending->DstPort; + out->ProcessID = pending->ProcessID; + LARGE_INTEGER now; + KeQuerySystemTime(&now); + PendingSetDeliveryState(Ctx, pending, PendingDeliveryDelivered, &now); + WdfRequestCompleteWithInformation(request, STATUS_SUCCESS, sizeof(WINREDIRECT_PENDING_CONN)); + } +} + +NTSTATUS DriverEntry(_In_ PDRIVER_OBJECT DriverObject, _In_ PUNICODE_STRING RegistryPath) +{ + NTSTATUS status; + WDF_DRIVER_CONFIG driverConfig; + WDF_OBJECT_ATTRIBUTES driverAttrs; + WDFDRIVER driver; + WDFDEVICE device; + PWDFDEVICE_INIT deviceInit; + WDF_OBJECT_ATTRIBUTES deviceAttrs; + UNICODE_STRING deviceName = RTL_CONSTANT_STRING(DEVICE_NAME); + UNICODE_STRING symlinkName = RTL_CONSTANT_STRING(SYMLINK_NAME); + PDRIVER_CONTEXT ctx; + + WDF_DRIVER_CONFIG_INIT(&driverConfig, WDF_NO_EVENT_CALLBACK); + driverConfig.DriverInitFlags = WdfDriverInitNonPnpDriver; + driverConfig.EvtDriverUnload = EvtDriverUnload; + + WDF_OBJECT_ATTRIBUTES_INIT(&driverAttrs); + status = WdfDriverCreate(DriverObject, RegistryPath, &driverAttrs, &driverConfig, &driver); + if (!NT_SUCCESS(status)) return status; + + deviceInit = WdfControlDeviceInitAllocate(driver, &SDDL_DEVOBJ_SYS_ALL_ADM_ALL); + if (!deviceInit) return STATUS_INSUFFICIENT_RESOURCES; + + WdfDeviceInitSetDeviceType(deviceInit, FILE_DEVICE_NETWORK); + WdfDeviceInitSetCharacteristics(deviceInit, FILE_DEVICE_SECURE_OPEN, FALSE); + + status = WdfDeviceInitAssignName(deviceInit, &deviceName); + if (!NT_SUCCESS(status)) { WdfDeviceInitFree(deviceInit); return status; } + + WDF_OBJECT_ATTRIBUTES_INIT_CONTEXT_TYPE(&deviceAttrs, DRIVER_CONTEXT); + deviceAttrs.ExecutionLevel = WdfExecutionLevelPassive; + status = WdfDeviceCreate(&deviceInit, &deviceAttrs, &device); + if (!NT_SUCCESS(status)) return status; + + status = WdfDeviceCreateSymbolicLink(device, &symlinkName); + if (!NT_SUCCESS(status)) return status; + + ctx = GetDriverContext(device); + RtlZeroMemory(ctx, sizeof(DRIVER_CONTEXT)); + ctx->Device = device; + InitializeListHead(&ctx->PendingList); + KeInitializeSpinLock(&ctx->PendingLock); + KeInitializeSpinLock(&ctx->ConfigLock); + g_Ctx = ctx; + + // Create manual-dispatch queue for pending IOCTLs + WDF_IO_QUEUE_CONFIG queueConfig; + WDF_IO_QUEUE_CONFIG_INIT(&queueConfig, WdfIoQueueDispatchManual); + queueConfig.EvtIoCanceledOnQueue = EvtIoCanceledOnQueue; + status = WdfIoQueueCreate(device, &queueConfig, WDF_NO_OBJECT_ATTRIBUTES, &ctx->PendingIoctlQueue); + if (!NT_SUCCESS(status)) return status; + + // Create default queue for all other IOCTLs + WDF_IO_QUEUE_CONFIG defaultQueueConfig; + WDF_IO_QUEUE_CONFIG_INIT_DEFAULT_QUEUE(&defaultQueueConfig, WdfIoQueueDispatchParallel); + defaultQueueConfig.EvtIoDeviceControl = EvtIoDeviceControl; + status = WdfIoQueueCreate(device, &defaultQueueConfig, WDF_NO_OBJECT_ATTRIBUTES, NULL); + if (!NT_SUCCESS(status)) return status; + + // Create timeout timer (sweeps stale pending entries every 5 seconds) + WDF_TIMER_CONFIG timerConfig; + WDF_TIMER_CONFIG_INIT_PERIODIC(&timerConfig, EvtTimeoutTimer, 5000); + WDF_OBJECT_ATTRIBUTES timerAttrs; + WDF_OBJECT_ATTRIBUTES_INIT(&timerAttrs); + timerAttrs.ParentObject = device; + status = WdfTimerCreate(&timerConfig, &timerAttrs, &ctx->TimeoutTimer); + if (!NT_SUCCESS(status)) return status; + + // Create work item for timeout processing at PASSIVE_LEVEL + WDF_WORKITEM_CONFIG workItemConfig; + WDF_WORKITEM_CONFIG_INIT(&workItemConfig, EvtTimeoutWorkItem); + WDF_OBJECT_ATTRIBUTES workItemAttrs; + WDF_OBJECT_ATTRIBUTES_INIT(&workItemAttrs); + workItemAttrs.ParentObject = device; + status = WdfWorkItemCreate(&workItemConfig, &workItemAttrs, &ctx->TimeoutWorkItem); + if (!NT_SUCCESS(status)) return status; + + WDF_WORKITEM_CONFIG pendingDeliveryConfig; + WDF_WORKITEM_CONFIG_INIT(&pendingDeliveryConfig, EvtPendingDeliveryWorkItem); + WDF_OBJECT_ATTRIBUTES pendingDeliveryAttrs; + WDF_OBJECT_ATTRIBUTES_INIT(&pendingDeliveryAttrs); + pendingDeliveryAttrs.ParentObject = device; + status = WdfWorkItemCreate(&pendingDeliveryConfig, &pendingDeliveryAttrs, &ctx->PendingDeliveryWorkItem); + if (!NT_SUCCESS(status)) return status; + + WDF_WORKITEM_CONFIG fatalConfig; + WDF_WORKITEM_CONFIG_INIT(&fatalConfig, EvtFatalWorkItem); + WDF_OBJECT_ATTRIBUTES fatalAttrs; + WDF_OBJECT_ATTRIBUTES_INIT(&fatalAttrs); + fatalAttrs.ParentObject = device; + status = WdfWorkItemCreate(&fatalConfig, &fatalAttrs, &ctx->FatalWorkItem); + if (!NT_SUCCESS(status)) return status; + + WdfControlFinishInitializing(device); + return STATUS_SUCCESS; +} + +void EvtDriverUnload(_In_ WDFDRIVER Driver) +{ + NTSTATUS fatalStatus; + + UNREFERENCED_PARAMETER(Driver); + if (g_Ctx) { + WdfWorkItemFlush(g_Ctx->FatalWorkItem); + fatalStatus = ReadFatalStatus(g_Ctx); + ShutdownRedirect( + g_Ctx, + VERDICT_PERMIT, + fatalStatus != STATUS_SUCCESS ? fatalStatus : STATUS_CANCELLED); + } +} + +void EvtIoDeviceControl( + _In_ WDFQUEUE Queue, + _In_ WDFREQUEST Request, + _In_ size_t OutputBufferLength, + _In_ size_t InputBufferLength, + _In_ ULONG IoControlCode) +{ + UNREFERENCED_PARAMETER(Queue); + UNREFERENCED_PARAMETER(OutputBufferLength); + UNREFERENCED_PARAMETER(InputBufferLength); + PDRIVER_CONTEXT ctx = g_Ctx; + NTSTATUS status = STATUS_SUCCESS; + NTSTATUS fatalStatus = STATUS_SUCCESS; + PVOID inBuf = NULL; + size_t inLen = 0; + + if (!ctx) { + WdfRequestComplete(Request, STATUS_INVALID_DEVICE_STATE); + return; + } + + fatalStatus = ReadFatalStatus(ctx); + + switch (IoControlCode) { + case IOCTL_WINREDIRECT_SET_CONFIG: + if (fatalStatus != STATUS_SUCCESS) { + WdfRequestComplete(Request, fatalStatus); + break; + } + if (ctx->Running) { + WdfRequestComplete(Request, STATUS_DEVICE_BUSY); + break; + } + status = WdfRequestRetrieveInputBuffer(Request, sizeof(WINREDIRECT_CONFIG), &inBuf, &inLen); + if (NT_SUCCESS(status)) { + WINREDIRECT_CONFIG* config = (WINREDIRECT_CONFIG*)inBuf; + NET_LUID tunLuid = {0}; + const GUID nullGuid = {0}; + KIRQL oldIrql; + if (config->RedirectPort == 0 || config->ProxyPID == 0 || InlineIsEqualGUID(&nullGuid, &config->TunGuid)) { + status = STATUS_INVALID_PARAMETER; + } else { + status = ConvertInterfaceGuidToLuid(&config->TunGuid, &tunLuid); + } + if (NT_SUCCESS(status)) { + KeAcquireSpinLock(&ctx->ConfigLock, &oldIrql); + RtlCopyMemory(&ctx->Config, config, sizeof(WINREDIRECT_CONFIG)); + ctx->TunLuid = tunLuid; + ctx->HasTunLuid = TRUE; + KeReleaseSpinLock(&ctx->ConfigLock, oldIrql); + } + } + WdfRequestComplete(Request, status); + break; + + case IOCTL_WINREDIRECT_START: { + CONFIG_SNAPSHOT snapshot; + if (fatalStatus != STATUS_SUCCESS) { + WdfRequestComplete(Request, fatalStatus); + break; + } + if (InterlockedCompareExchange(&ctx->Running, TRUE, FALSE) != FALSE) { + WdfRequestComplete(Request, STATUS_ALREADY_REGISTERED); + break; + } + snapshot = ReadConfigSnapshot(ctx); + if (!snapshot.HasTunLuid || snapshot.Config.RedirectPort == 0 || snapshot.Config.ProxyPID == 0) { + InterlockedExchange(&ctx->Running, FALSE); + WdfRequestComplete(Request, STATUS_INVALID_DEVICE_STATE); + break; + } + status = WfpSetup(ctx); + if (NT_SUCCESS(status)) { + WdfTimerStart(ctx->TimeoutTimer, WDF_REL_TIMEOUT_IN_SEC(5)); + } else { + InterlockedExchange(&ctx->Running, FALSE); + } + WdfRequestComplete(Request, status); + break; + } + + case IOCTL_WINREDIRECT_STOP: + if (fatalStatus != STATUS_SUCCESS) { + ShutdownRedirect(ctx, VERDICT_PERMIT, fatalStatus); + } else { + ShutdownRedirect(ctx, VERDICT_PERMIT, STATUS_CANCELLED); + } + WdfRequestComplete(Request, STATUS_SUCCESS); + break; + + case IOCTL_WINREDIRECT_GET_PENDING: + if (fatalStatus != STATUS_SUCCESS) { + WdfRequestComplete(Request, fatalStatus); + break; + } + if (!ctx->Running) { + WdfRequestComplete(Request, STATUS_DEVICE_NOT_READY); + break; + } + // Forward to manual queue - will be completed when a connection arrives + status = WdfRequestForwardToIoQueue(Request, ctx->PendingIoctlQueue); + if (!NT_SUCCESS(status)) { + WdfRequestComplete(Request, status); + } else { + WdfWorkItemEnqueue(ctx->PendingDeliveryWorkItem); + } + break; + + case IOCTL_WINREDIRECT_SET_VERDICT: { + if (fatalStatus != STATUS_SUCCESS) { + WdfRequestComplete(Request, fatalStatus); + break; + } + if (!ctx->Running) { + WdfRequestComplete(Request, STATUS_DEVICE_NOT_READY); + break; + } + status = WdfRequestRetrieveInputBuffer(Request, sizeof(WINREDIRECT_VERDICT), &inBuf, &inLen); + if (!NT_SUCCESS(status)) { + WdfRequestComplete(Request, status); + break; + } + WINREDIRECT_VERDICT* v = (WINREDIRECT_VERDICT*)inBuf; + if (v->Verdict != VERDICT_REDIRECT && v->Verdict != VERDICT_PERMIT) { + WdfRequestComplete(Request, STATUS_INVALID_PARAMETER); + break; + } + PPENDING_ENTRY entry = PendingFindByID(ctx, v->ConnID); + if (entry) { + ExecuteVerdict(ctx, entry, v->Verdict); + ExFreePoolWithTag(entry, 'rniW'); + } + WdfRequestComplete(Request, entry ? STATUS_SUCCESS : STATUS_NOT_FOUND); + break; + } + + case IOCTL_WINREDIRECT_GET_FATAL_INFO: { + PVOID outBuf; + status = WdfRequestRetrieveOutputBuffer(Request, sizeof(WINREDIRECT_FATAL_INFO), &outBuf, NULL); + if (!NT_SUCCESS(status)) { + WdfRequestComplete(Request, status); + break; + } + WINREDIRECT_FATAL_INFO* info = (WINREDIRECT_FATAL_INFO*)outBuf; + info->Status = (UINT32)ReadFatalStatus(ctx); + RtlStringCbCopyA(info->Message, sizeof(info->Message), ctx->FatalMessage); + WdfRequestCompleteWithInformation(Request, STATUS_SUCCESS, sizeof(WINREDIRECT_FATAL_INFO)); + break; + } + + default: + WdfRequestComplete(Request, STATUS_INVALID_DEVICE_REQUEST); + break; + } +} + +void EvtIoCanceledOnQueue(_In_ WDFQUEUE Queue, _In_ WDFREQUEST Request) +{ + UNREFERENCED_PARAMETER(Queue); + WdfRequestComplete(Request, STATUS_CANCELLED); +} + +void EvtTimeoutTimer(_In_ WDFTIMER Timer) +{ + UNREFERENCED_PARAMETER(Timer); + if (!g_Ctx) return; + WdfWorkItemEnqueue(g_Ctx->TimeoutWorkItem); +} + +void EvtTimeoutWorkItem(_In_ WDFWORKITEM WorkItem) +{ + UNREFERENCED_PARAMETER(WorkItem); + if (!g_Ctx) return; + if (ReadFatalStatus(g_Ctx) != STATUS_SUCCESS) return; + + for (;;) { + LARGE_INTEGER now; + PPENDING_ENTRY expired = NULL; + KIRQL oldIrql; + + KeQuerySystemTime(&now); + KeAcquireSpinLock(&g_Ctx->PendingLock, &oldIrql); + + PLIST_ENTRY entry = g_Ctx->PendingList.Flink; + while (entry != &g_Ctx->PendingList) { + PPENDING_ENTRY pending = CONTAINING_RECORD(entry, PENDING_ENTRY, ListEntry); + LONGLONG timeoutSeconds = 0; + entry = entry->Flink; + + if (pending->DeliveryState == PendingDeliveryQueued) { + timeoutSeconds = PENDING_QUEUED_TIMEOUT_SECONDS; + } else if (pending->DeliveryState == PendingDeliveryDelivered) { + timeoutSeconds = PENDING_DELIVERED_TIMEOUT_SECONDS; + } else { + continue; + } + + LONGLONG elapsed = (now.QuadPart - pending->Timestamp.QuadPart) / 10000000LL; + if (elapsed >= timeoutSeconds) { + RemoveEntryList(&pending->ListEntry); + expired = pending; + break; + } + } + + KeReleaseSpinLock(&g_Ctx->PendingLock, oldIrql); + + if (!expired) { + break; + } + + ExecuteVerdict(g_Ctx, expired, VERDICT_PERMIT); + ExFreePoolWithTag(expired, 'rniW'); + } +} + +void EvtPendingDeliveryWorkItem(_In_ WDFWORKITEM WorkItem) +{ + UNREFERENCED_PARAMETER(WorkItem); + if (!g_Ctx || !g_Ctx->Running || ReadFatalStatus(g_Ctx) != STATUS_SUCCESS) return; + TryCompletePendingRequests(g_Ctx); +} + +void EvtFatalWorkItem(_In_ WDFWORKITEM WorkItem) +{ + NTSTATUS fatalStatus; + + UNREFERENCED_PARAMETER(WorkItem); + if (!g_Ctx) return; + + fatalStatus = ReadFatalStatus(g_Ctx); + if (fatalStatus == STATUS_SUCCESS) { + return; + } + + ShutdownRedirect(g_Ctx, VERDICT_PERMIT, fatalStatus); +} + +// --- WFP Setup --- + +NTSTATUS WfpSetup(_In_ PDRIVER_CONTEXT Ctx) +{ + NTSTATUS status; + FWPM_SESSION0 session = { .flags = FWPM_SESSION_FLAG_DYNAMIC }; + + status = FwpmEngineOpen0(NULL, RPC_C_AUTHN_DEFAULT, NULL, &session, &Ctx->EngineHandle); + if (!NT_SUCCESS(status)) return status; + + FWPM_SUBLAYER0 subLayer = { + .subLayerKey = WINREDIRECT_SUBLAYER_KEY, + .displayData = { .name = L"WinRedirect SubLayer" }, + .weight = MAXUINT16, + }; + status = FwpmSubLayerAdd0(Ctx->EngineHandle, &subLayer, NULL); + if (!NT_SUCCESS(status)) goto cleanup; + + // Create redirect handle + status = FwpsRedirectHandleCreate0(&WINREDIRECT_PROVIDER_KEY, 0, &Ctx->RedirectHandle); + if (!NT_SUCCESS(status)) goto cleanup; + + // Register callouts + FWPS_CALLOUT1 sCalloutV4 = { + .calloutKey = WINREDIRECT_CALLOUT_V4_KEY, + .classifyFn = ClassifyFnV4, + .notifyFn = NotifyFn, + }; + status = FwpsCalloutRegister1(WdfDeviceWdmGetDeviceObject(Ctx->Device), &sCalloutV4, &Ctx->CalloutIdV4); + if (!NT_SUCCESS(status)) goto cleanup; + + FWPS_CALLOUT1 sCalloutV6 = { + .calloutKey = WINREDIRECT_CALLOUT_V6_KEY, + .classifyFn = ClassifyFnV6, + .notifyFn = NotifyFn, + }; + status = FwpsCalloutRegister1(WdfDeviceWdmGetDeviceObject(Ctx->Device), &sCalloutV6, &Ctx->CalloutIdV6); + if (!NT_SUCCESS(status)) goto cleanup; + + // Add callouts to BFE + FWPM_CALLOUT0 mCalloutV4 = { + .calloutKey = WINREDIRECT_CALLOUT_V4_KEY, + .displayData = { .name = L"WinRedirect V4 Callout" }, + .applicableLayer = FWPM_LAYER_ALE_CONNECT_REDIRECT_V4, + }; + status = FwpmCalloutAdd0(Ctx->EngineHandle, &mCalloutV4, NULL, NULL); + if (!NT_SUCCESS(status)) goto cleanup; + + FWPM_CALLOUT0 mCalloutV6 = { + .calloutKey = WINREDIRECT_CALLOUT_V6_KEY, + .displayData = { .name = L"WinRedirect V6 Callout" }, + .applicableLayer = FWPM_LAYER_ALE_CONNECT_REDIRECT_V6, + }; + status = FwpmCalloutAdd0(Ctx->EngineHandle, &mCalloutV6, NULL, NULL); + if (!NT_SUCCESS(status)) goto cleanup; + + // Add filters - condition: TCP only + FWPM_FILTER_CONDITION0 tcpCondition = { + .fieldKey = FWPM_CONDITION_IP_PROTOCOL, + .matchType = FWP_MATCH_EQUAL, + .conditionValue = { .type = FWP_UINT8, .uint8 = IPPROTO_TCP }, + }; + + FWPM_FILTER0 filterV4 = { + .displayData = { .name = L"WinRedirect V4 Filter" }, + .layerKey = FWPM_LAYER_ALE_CONNECT_REDIRECT_V4, + .subLayerKey = WINREDIRECT_SUBLAYER_KEY, + .action = { .type = FWP_ACTION_CALLOUT_TERMINATING, .calloutKey = WINREDIRECT_CALLOUT_V4_KEY }, + .weight = { .type = FWP_UINT8, .uint8 = 15 }, + .numFilterConditions = 1, + .filterCondition = &tcpCondition, + }; + status = FwpmFilterAdd0(Ctx->EngineHandle, &filterV4, NULL, &Ctx->FilterIdV4); + if (!NT_SUCCESS(status)) goto cleanup; + + FWPM_FILTER0 filterV6 = { + .displayData = { .name = L"WinRedirect V6 Filter" }, + .layerKey = FWPM_LAYER_ALE_CONNECT_REDIRECT_V6, + .subLayerKey = WINREDIRECT_SUBLAYER_KEY, + .action = { .type = FWP_ACTION_CALLOUT_TERMINATING, .calloutKey = WINREDIRECT_CALLOUT_V6_KEY }, + .weight = { .type = FWP_UINT8, .uint8 = 15 }, + .numFilterConditions = 1, + .filterCondition = &tcpCondition, + }; + status = FwpmFilterAdd0(Ctx->EngineHandle, &filterV6, NULL, &Ctx->FilterIdV6); + if (!NT_SUCCESS(status)) goto cleanup; + + return STATUS_SUCCESS; + +cleanup: + WfpCleanup(Ctx); + return status; +} + +void WfpCleanup(_In_ PDRIVER_CONTEXT Ctx) +{ + if (Ctx->CalloutIdV4) { + FwpsCalloutUnregisterById0(Ctx->CalloutIdV4); + Ctx->CalloutIdV4 = 0; + } + if (Ctx->CalloutIdV6) { + FwpsCalloutUnregisterById0(Ctx->CalloutIdV6); + Ctx->CalloutIdV6 = 0; + } + if (Ctx->RedirectHandle) { + FwpsRedirectHandleDestroy0(Ctx->RedirectHandle); + Ctx->RedirectHandle = NULL; + } + if (Ctx->EngineHandle) { + FwpmEngineClose0(Ctx->EngineHandle); + Ctx->EngineHandle = NULL; + } +} + +NTSTATUS NTAPI NotifyFn( + _In_ FWPS_CALLOUT_NOTIFY_TYPE notifyType, + _In_ const GUID* filterKey, + _Inout_ FWPS_FILTER1* filter) +{ + UNREFERENCED_PARAMETER(notifyType); + UNREFERENCED_PARAMETER(filterKey); + UNREFERENCED_PARAMETER(filter); + return STATUS_SUCCESS; +} + +// --- Classify callbacks --- + +static void ClassifyFnCommon( + _In_ UINT8 addressFamily, + _In_ const FWPS_INCOMING_VALUES0* inFixedValues, + _In_ const FWPS_INCOMING_METADATA_VALUES0* inMetaValues, + _Inout_opt_ void* layerData, + _In_opt_ const void* classifyContext, + _In_ const FWPS_FILTER1* filter, + _In_ UINT64 flowContext, + _Inout_ FWPS_CLASSIFY_OUT0* classifyOut, + _In_ UINT32 localAddrIdx, + _In_ UINT32 localPortIdx, + _In_ UINT32 remoteAddrIdx, + _In_ UINT32 remotePortIdx) +{ + PDRIVER_CONTEXT ctx = g_Ctx; + NTSTATUS fatalStatus; + NTSTATUS status; + CONFIG_SNAPSHOT snapshot; + PPENDING_ENTRY entry; + BEST_ROUTE_RESULT bestRoute; + UINT64 classifyHandle; + + UNREFERENCED_PARAMETER(layerData); + UNREFERENCED_PARAMETER(flowContext); + + if (!ctx || !ctx->Running) { + PermitClassify(classifyOut); + return; + } + + fatalStatus = ReadFatalStatus(ctx); + if (fatalStatus != STATUS_SUCCESS) { + PermitClassify(classifyOut); + return; + } + + // Must have write rights to modify the classify decision + if (!(classifyOut->rights & FWPS_RIGHT_ACTION_WRITE)) { + return; + } + + snapshot = ReadConfigSnapshot(ctx); + +#if (NTDDI_VERSION >= NTDDI_WIN8) + if (ctx->RedirectHandle && + FWPS_IS_METADATA_FIELD_PRESENT(inMetaValues, FWPS_METADATA_FIELD_REDIRECT_RECORD_HANDLE)) { + FWPS_CONNECTION_REDIRECT_STATE redirectState = + FwpsQueryConnectionRedirectState0(inMetaValues->redirectRecords, ctx->RedirectHandle, NULL); + switch (redirectState) { + case FWPS_CONNECTION_REDIRECTED_BY_SELF: + case FWPS_CONNECTION_PREVIOUSLY_REDIRECTED_BY_SELF: + PermitClassify(classifyOut); + return; + case FWPS_CONNECTION_NOT_REDIRECTED: + case FWPS_CONNECTION_REDIRECTED_BY_OTHER: + default: + break; + } + } +#endif + + if (snapshot.Config.ProxyPID != 0 && + FWPS_IS_METADATA_FIELD_PRESENT(inMetaValues, FWPS_METADATA_FIELD_PROCESS_ID) && + (UINT32)inMetaValues->processId == snapshot.Config.ProxyPID) { + PermitClassify(classifyOut); + return; + } + + // Allocate pending entry + entry = (PPENDING_ENTRY)ExAllocatePoolZero(NonPagedPoolNx, sizeof(PENDING_ENTRY), 'rniW'); + if (!entry) { + TriggerFatalAndPermitClassify(ctx, classifyOut, STATUS_INSUFFICIENT_RESOURCES, "allocate pending entry"); + return; + } + entry->ConnID = InterlockedIncrement64(&ctx->NextConnID); + entry->AddressFamily = addressFamily; + entry->FilterId = filter->filterId; + + // Extract addresses with NULL checks for IPv6 pointers + if (addressFamily == AF_INET) { + UINT32 srcIp = inFixedValues->incomingValue[localAddrIdx].value.uint32; + UINT32 dstIp = inFixedValues->incomingValue[remoteAddrIdx].value.uint32; + // WFP stores IPv4 in host byte order + *(UINT32*)entry->SrcAddr = RtlUlongByteSwap(srcIp); + *(UINT32*)entry->DstAddr = RtlUlongByteSwap(dstIp); + } else { + FWP_BYTE_ARRAY16* srcArr = inFixedValues->incomingValue[localAddrIdx].value.byteArray16; + FWP_BYTE_ARRAY16* dstArr = inFixedValues->incomingValue[remoteAddrIdx].value.byteArray16; + if (srcArr) { + RtlCopyMemory(entry->SrcAddr, srcArr->byteArray16, 16); + } + if (dstArr) { + RtlCopyMemory(entry->DstAddr, dstArr->byteArray16, 16); + } else { + ExFreePoolWithTag(entry, 'rniW'); + TriggerFatalAndPermitClassify(ctx, classifyOut, STATUS_INVALID_ADDRESS_COMPONENT, "ipv6 null destination"); + return; + } + } + entry->SrcPort = inFixedValues->incomingValue[localPortIdx].value.uint16; + entry->DstPort = inFixedValues->incomingValue[remotePortIdx].value.uint16; + + if (IsLoopbackAddress(addressFamily, entry->DstAddr)) { + ExFreePoolWithTag(entry, 'rniW'); + PermitClassify(classifyOut); + return; + } + + if (!TryBestRouteForEntry(&snapshot, entry, &bestRoute) || bestRoute == BestRouteOther) { + ExFreePoolWithTag(entry, 'rniW'); + PermitClassify(classifyOut); + return; + } + // Windows auto-redirect is best-effort: only redirect connections that are + // positively identified as already routed to the TUN. If route lookup says + // "not TUN" or fails for a context we do not currently characterize, leave + // the original connect alone instead of redirecting unknown traffic. + + // Extract PID from metadata + if (FWPS_IS_METADATA_FIELD_PRESENT(inMetaValues, FWPS_METADATA_FIELD_PROCESS_ID)) { + entry->ProcessID = (UINT32)inMetaValues->processId; + } + + if (!classifyContext) { + ExFreePoolWithTag(entry, 'rniW'); + TriggerFatalAndPermitClassify(ctx, classifyOut, STATUS_INVALID_DEVICE_STATE, "no classify context"); + return; + } + + // Pend the classify + status = FwpsAcquireClassifyHandle0((void*)classifyContext, 0, &classifyHandle); + if (!NT_SUCCESS(status)) { + ExFreePoolWithTag(entry, 'rniW'); + TriggerFatalAndPermitClassify(ctx, classifyOut, status, "acquire classify handle"); + return; + } + + entry->ClassifyHandle = classifyHandle; + entry->ClassifyOut = *classifyOut; + + status = FwpsAcquireWritableLayerDataPointer0( + classifyHandle, filter->filterId, 0, + &entry->WritableLayerData, classifyOut); + if (!NT_SUCCESS(status) || !entry->WritableLayerData) { + FwpsReleaseClassifyHandle0(classifyHandle); + ExFreePoolWithTag(entry, 'rniW'); + TriggerFatalAndPermitClassify(ctx, classifyOut, !NT_SUCCESS(status) ? status : STATUS_INVALID_DEVICE_STATE, "acquire writable layer data"); + return; + } + + status = FwpsPendClassify0(classifyHandle, filter->filterId, 0, classifyOut); + if (!NT_SUCCESS(status)) { + FwpsApplyModifiedLayerData0( + classifyHandle, + entry->WritableLayerData, + FWPS_CLASSIFY_FLAG_REAUTHORIZE_IF_MODIFIED_BY_OTHERS); + FwpsReleaseClassifyHandle0(classifyHandle); + ExFreePoolWithTag(entry, 'rniW'); + TriggerFatalAndPermitClassify(ctx, classifyOut, status, "pend classify"); + return; + } + + BlockClassify(classifyOut); + + KeQuerySystemTime(&entry->Timestamp); + entry->DeliveryState = PendingDeliveryQueued; + PendingInsert(ctx, entry); + WdfWorkItemEnqueue(ctx->PendingDeliveryWorkItem); +} + +void NTAPI ClassifyFnV4( + _In_ const FWPS_INCOMING_VALUES0* inFixedValues, + _In_ const FWPS_INCOMING_METADATA_VALUES0* inMetaValues, + _Inout_opt_ void* layerData, + _In_opt_ const void* classifyContext, + _In_ const FWPS_FILTER1* filter, + _In_ UINT64 flowContext, + _Inout_ FWPS_CLASSIFY_OUT0* classifyOut) +{ + ClassifyFnCommon(AF_INET, inFixedValues, inMetaValues, layerData, + classifyContext, filter, flowContext, classifyOut, + FWPS_FIELD_ALE_CONNECT_REDIRECT_V4_IP_LOCAL_ADDRESS, + FWPS_FIELD_ALE_CONNECT_REDIRECT_V4_IP_LOCAL_PORT, + FWPS_FIELD_ALE_CONNECT_REDIRECT_V4_IP_REMOTE_ADDRESS, + FWPS_FIELD_ALE_CONNECT_REDIRECT_V4_IP_REMOTE_PORT); +} + +void NTAPI ClassifyFnV6( + _In_ const FWPS_INCOMING_VALUES0* inFixedValues, + _In_ const FWPS_INCOMING_METADATA_VALUES0* inMetaValues, + _Inout_opt_ void* layerData, + _In_opt_ const void* classifyContext, + _In_ const FWPS_FILTER1* filter, + _In_ UINT64 flowContext, + _Inout_ FWPS_CLASSIFY_OUT0* classifyOut) +{ + ClassifyFnCommon(AF_INET6, inFixedValues, inMetaValues, layerData, + classifyContext, filter, flowContext, classifyOut, + FWPS_FIELD_ALE_CONNECT_REDIRECT_V6_IP_LOCAL_ADDRESS, + FWPS_FIELD_ALE_CONNECT_REDIRECT_V6_IP_LOCAL_PORT, + FWPS_FIELD_ALE_CONNECT_REDIRECT_V6_IP_REMOTE_ADDRESS, + FWPS_FIELD_ALE_CONNECT_REDIRECT_V6_IP_REMOTE_PORT); +} + +// --- Pending connection management --- + +void PendingInsert(_In_ PDRIVER_CONTEXT Ctx, _In_ PPENDING_ENTRY Entry) +{ + KIRQL oldIrql; + KeAcquireSpinLock(&Ctx->PendingLock, &oldIrql); + InsertTailList(&Ctx->PendingList, &Entry->ListEntry); + KeReleaseSpinLock(&Ctx->PendingLock, oldIrql); +} + +PPENDING_ENTRY PendingFindByID(_In_ PDRIVER_CONTEXT Ctx, _In_ UINT64 ConnID) +{ + PPENDING_ENTRY found = NULL; + KIRQL oldIrql; + KeAcquireSpinLock(&Ctx->PendingLock, &oldIrql); + PLIST_ENTRY entry = Ctx->PendingList.Flink; + while (entry != &Ctx->PendingList) { + PPENDING_ENTRY pending = CONTAINING_RECORD(entry, PENDING_ENTRY, ListEntry); + if (pending->ConnID == ConnID) { + RemoveEntryList(entry); + found = pending; + break; + } + entry = entry->Flink; + } + KeReleaseSpinLock(&Ctx->PendingLock, oldIrql); + return found; +} + +void PendingFlushAll(_In_ PDRIVER_CONTEXT Ctx, _In_ UINT32 Verdict) +{ + for (;;) { + KIRQL oldIrql; + PPENDING_ENTRY pending = NULL; + + KeAcquireSpinLock(&Ctx->PendingLock, &oldIrql); + if (!IsListEmpty(&Ctx->PendingList)) { + PLIST_ENTRY entry = RemoveHeadList(&Ctx->PendingList); + pending = CONTAINING_RECORD(entry, PENDING_ENTRY, ListEntry); + } + KeReleaseSpinLock(&Ctx->PendingLock, oldIrql); + + if (!pending) { + break; + } + + ExecuteVerdict(Ctx, pending, Verdict); + ExFreePoolWithTag(pending, 'rniW'); + } +} + +// --- Verdict execution --- + +void ExecuteVerdict(_In_ PDRIVER_CONTEXT Ctx, _In_ PPENDING_ENTRY Entry, _In_ UINT32 Verdict) +{ + FWPS_CLASSIFY_OUT0 classifyOut = Entry->ClassifyOut; + FWPS_CONNECT_REQUEST0* connReq = (FWPS_CONNECT_REQUEST0*)Entry->WritableLayerData; + NTSTATUS redirectStatus = STATUS_SUCCESS; + CONFIG_SNAPSHOT snapshot; + + if (Verdict == VERDICT_REDIRECT) { + snapshot = ReadConfigSnapshot(Ctx); + if (!connReq || + !snapshot.HasTunLuid || + snapshot.Config.RedirectPort == 0 || + snapshot.Config.ProxyPID == 0 || + Ctx->RedirectHandle == NULL) { + redirectStatus = STATUS_INVALID_DEVICE_STATE; + } else { + SOCKADDR_STORAGE* redirectContext = + (SOCKADDR_STORAGE*)ExAllocatePoolZero(NonPagedPoolNx, sizeof(SOCKADDR_STORAGE) * 2, 'rniW'); + if (!redirectContext) { + redirectStatus = STATUS_INSUFFICIENT_RESOURCES; + } else { + RtlCopyMemory(&redirectContext[0], &connReq->remoteAddressAndPort, sizeof(SOCKADDR_STORAGE)); + RtlCopyMemory(&redirectContext[1], &connReq->localAddressAndPort, sizeof(SOCKADDR_STORAGE)); + + if (Entry->AddressFamily == AF_INET) { + SOCKADDR_IN* localAddr = (SOCKADDR_IN*)&connReq->localAddressAndPort; + SOCKADDR_IN* addr = (SOCKADDR_IN*)&connReq->remoteAddressAndPort; + addr->sin_family = AF_INET; + if (localAddr->sin_addr.s_addr == 0) { + addr->sin_addr.s_addr = RtlUlongByteSwap(0x7F000001); // 127.0.0.1 + } else { + addr->sin_addr = localAddr->sin_addr; + } + addr->sin_port = RtlUshortByteSwap(snapshot.Config.RedirectPort); + } else if (Entry->AddressFamily == AF_INET6) { + SOCKADDR_IN6* localAddr = (SOCKADDR_IN6*)&connReq->localAddressAndPort; + SOCKADDR_IN6* addr = (SOCKADDR_IN6*)&connReq->remoteAddressAndPort; + if (IsAnyAddress(AF_INET6, localAddr->sin6_addr.u.Byte)) { + RtlZeroMemory(addr, sizeof(SOCKADDR_IN6)); + addr->sin6_family = AF_INET6; + addr->sin6_addr.u.Byte[15] = 1; // ::1 + } else { + *addr = *localAddr; + addr->sin6_family = AF_INET6; + } + addr->sin6_port = RtlUshortByteSwap(snapshot.Config.RedirectPort); + } else { + redirectStatus = STATUS_INVALID_PARAMETER; + } + + if (NT_SUCCESS(redirectStatus)) { + connReq->localRedirectHandle = Ctx->RedirectHandle; + connReq->localRedirectTargetPID = snapshot.Config.ProxyPID; + connReq->localRedirectContext = redirectContext; + connReq->localRedirectContextSize = sizeof(SOCKADDR_STORAGE) * 2; + } else { + ExFreePoolWithTag(redirectContext, 'rniW'); + } + } + } + + if (!NT_SUCCESS(redirectStatus)) { + TriggerFatal(Ctx, redirectStatus, "execute redirect"); + Verdict = VERDICT_PERMIT; + } + } + + if (Entry->WritableLayerData) { + FwpsApplyModifiedLayerData0( + Entry->ClassifyHandle, + Entry->WritableLayerData, + FWPS_CLASSIFY_FLAG_REAUTHORIZE_IF_MODIFIED_BY_OTHERS); + Entry->WritableLayerData = NULL; + } + + classifyOut.actionType = FWP_ACTION_PERMIT; + classifyOut.rights &= ~FWPS_RIGHT_ACTION_WRITE; + + FwpsCompleteClassify0(Entry->ClassifyHandle, 0, &classifyOut); + FwpsReleaseClassifyHandle0(Entry->ClassifyHandle); +} diff --git a/internal/winredirect/driver/winredirect.h b/internal/winredirect/driver/winredirect.h new file mode 100644 index 00000000..5a231fbb --- /dev/null +++ b/internal/winredirect/driver/winredirect.h @@ -0,0 +1,179 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Device names +#define DEVICE_NAME L"\\Device\\WinRedirect" +#define SYMLINK_NAME L"\\DosDevices\\WinRedirect" + +// IOCTL codes - must match Go types_windows.go +#define IOCTL_WINREDIRECT_SET_CONFIG CTL_CODE(FILE_DEVICE_NETWORK, 0x800, METHOD_BUFFERED, FILE_ANY_ACCESS) +#define IOCTL_WINREDIRECT_START CTL_CODE(FILE_DEVICE_NETWORK, 0x801, METHOD_BUFFERED, FILE_ANY_ACCESS) +#define IOCTL_WINREDIRECT_STOP CTL_CODE(FILE_DEVICE_NETWORK, 0x802, METHOD_BUFFERED, FILE_ANY_ACCESS) +#define IOCTL_WINREDIRECT_GET_PENDING CTL_CODE(FILE_DEVICE_NETWORK, 0x803, METHOD_BUFFERED, FILE_ANY_ACCESS) +#define IOCTL_WINREDIRECT_SET_VERDICT CTL_CODE(FILE_DEVICE_NETWORK, 0x804, METHOD_BUFFERED, FILE_ANY_ACCESS) +#define IOCTL_WINREDIRECT_GET_FATAL_INFO CTL_CODE(FILE_DEVICE_NETWORK, 0x805, METHOD_BUFFERED, FILE_ANY_ACCESS) + +// Verdict values +#define VERDICT_REDIRECT 0 +#define VERDICT_PERMIT 1 + +// Shared structures - must match Go types_windows.go layout + +#pragma pack(push, 1) + +typedef struct _WINREDIRECT_CONFIG { + UINT16 RedirectPort; + UINT8 _pad0[2]; + UINT32 ProxyPID; + GUID TunGuid; +} WINREDIRECT_CONFIG; + +typedef struct _WINREDIRECT_PENDING_CONN { + UINT64 ConnID; + UINT8 AddressFamily; + UINT8 _pad0[3]; + UINT8 SrcAddr[16]; + UINT16 SrcPort; + UINT8 _pad1[2]; + UINT8 DstAddr[16]; + UINT16 DstPort; + UINT8 _pad2[2]; + UINT32 ProcessID; +} WINREDIRECT_PENDING_CONN; + +typedef struct _WINREDIRECT_VERDICT { + UINT64 ConnID; + UINT32 Verdict; + UINT8 _pad0[4]; +} WINREDIRECT_VERDICT; + +typedef struct _WINREDIRECT_FATAL_INFO { + UINT32 Status; + CHAR Message[128]; +} WINREDIRECT_FATAL_INFO; + +#pragma pack(pop) + +typedef struct _CONFIG_SNAPSHOT { + WINREDIRECT_CONFIG Config; + NET_LUID TunLuid; + BOOLEAN HasTunLuid; +} CONFIG_SNAPSHOT, *PCONFIG_SNAPSHOT; + +typedef enum _PENDING_DELIVERY_STATE { + PendingDeliveryQueued = 0, + PendingDeliveryCopying = 1, + PendingDeliveryDelivered = 2, +} PENDING_DELIVERY_STATE; + +// Internal pending connection entry +typedef struct _PENDING_ENTRY { + LIST_ENTRY ListEntry; + UINT64 ConnID; + UINT64 ClassifyHandle; + UINT64 FilterId; + FWPS_CLASSIFY_OUT0 ClassifyOut; + PVOID WritableLayerData; + volatile LONG DeliveryState; + UINT8 AddressFamily; + UINT8 SrcAddr[16]; + UINT16 SrcPort; + UINT8 DstAddr[16]; + UINT16 DstPort; + UINT32 ProcessID; + LARGE_INTEGER Timestamp; +} PENDING_ENTRY, *PPENDING_ENTRY; + +// Global driver context +typedef struct _DRIVER_CONTEXT { + WDFDEVICE Device; + WDFQUEUE PendingIoctlQueue; + + // WFP handles + HANDLE EngineHandle; + UINT32 CalloutIdV4; + UINT32 CalloutIdV6; + UINT64 FilterIdV4; + UINT64 FilterIdV6; + HANDLE RedirectHandle; + + // Configuration (protected by ConfigLock) + KSPIN_LOCK ConfigLock; + WINREDIRECT_CONFIG Config; + NET_LUID TunLuid; + BOOLEAN HasTunLuid; + volatile LONG Running; + volatile LONG FatalStatus; + CHAR FatalMessage[128]; + + // Pending connections (protected by PendingLock) + LIST_ENTRY PendingList; + KSPIN_LOCK PendingLock; + volatile LONG64 NextConnID; + + // Timeout timer + work item + WDFTIMER TimeoutTimer; + WDFWORKITEM TimeoutWorkItem; + WDFWORKITEM PendingDeliveryWorkItem; + WDFWORKITEM FatalWorkItem; +} DRIVER_CONTEXT, *PDRIVER_CONTEXT; + +WDF_DECLARE_CONTEXT_TYPE_WITH_NAME(DRIVER_CONTEXT, GetDriverContext) + +// Function declarations +DRIVER_INITIALIZE DriverEntry; +EVT_WDF_DRIVER_UNLOAD EvtDriverUnload; +EVT_WDF_IO_QUEUE_IO_DEVICE_CONTROL EvtIoDeviceControl; +EVT_WDF_IO_QUEUE_IO_CANCELED_ON_QUEUE EvtIoCanceledOnQueue; +EVT_WDF_TIMER EvtTimeoutTimer; +EVT_WDF_WORKITEM EvtTimeoutWorkItem; +EVT_WDF_WORKITEM EvtPendingDeliveryWorkItem; +EVT_WDF_WORKITEM EvtFatalWorkItem; + +// WFP functions +NTSTATUS WfpSetup(_In_ PDRIVER_CONTEXT Ctx); +void WfpCleanup(_In_ PDRIVER_CONTEXT Ctx); + +// Classify callbacks +void NTAPI ClassifyFnV4( + _In_ const FWPS_INCOMING_VALUES0* inFixedValues, + _In_ const FWPS_INCOMING_METADATA_VALUES0* inMetaValues, + _Inout_opt_ void* layerData, + _In_opt_ const void* classifyContext, + _In_ const FWPS_FILTER1* filter, + _In_ UINT64 flowContext, + _Inout_ FWPS_CLASSIFY_OUT0* classifyOut +); + +void NTAPI ClassifyFnV6( + _In_ const FWPS_INCOMING_VALUES0* inFixedValues, + _In_ const FWPS_INCOMING_METADATA_VALUES0* inMetaValues, + _Inout_opt_ void* layerData, + _In_opt_ const void* classifyContext, + _In_ const FWPS_FILTER1* filter, + _In_ UINT64 flowContext, + _Inout_ FWPS_CLASSIFY_OUT0* classifyOut +); + +NTSTATUS NTAPI NotifyFn( + _In_ FWPS_CALLOUT_NOTIFY_TYPE notifyType, + _In_ const GUID* filterKey, + _Inout_ FWPS_FILTER1* filter +); + +// Pending management +void PendingInsert(_In_ PDRIVER_CONTEXT Ctx, _In_ PPENDING_ENTRY Entry); +PPENDING_ENTRY PendingFindByID(_In_ PDRIVER_CONTEXT Ctx, _In_ UINT64 ConnID); +void PendingFlushAll(_In_ PDRIVER_CONTEXT Ctx, _In_ UINT32 Verdict); + +// Verdict execution +void ExecuteVerdict(_In_ PDRIVER_CONTEXT Ctx, _In_ PPENDING_ENTRY Entry, _In_ UINT32 Verdict); diff --git a/internal/winredirect/driver/winredirect.inf b/internal/winredirect/driver/winredirect.inf new file mode 100644 index 00000000..4b671a0b --- /dev/null +++ b/internal/winredirect/driver/winredirect.inf @@ -0,0 +1,59 @@ +[Version] +Signature = "$WINDOWS NT$" +Class = WFPCALLOUTDRIVER +ClassGuid = {57465043-616C-6C6F-7574-5F636C617373} +Provider = %ProviderString% +CatalogFile = winredirect.cat +DriverVer = +PnpLockdown = 1 + +[DestinationDirs] +DefaultDestDir = 13 +WinRedirect.Files = 13 + +[DefaultInstall.NTamd64] +CopyFiles = WinRedirect.Files + +[DefaultInstall.NTarm64] +CopyFiles = WinRedirect.Files + +[DefaultInstall.NTarm] +CopyFiles = WinRedirect.Files + +[DefaultInstall.NTx86] +CopyFiles = WinRedirect.Files + +[DefaultInstall.NTamd64.Services] +AddService = WinRedirect,,WinRedirect.Service + +[DefaultInstall.NTarm64.Services] +AddService = WinRedirect,,WinRedirect.Service + +[DefaultInstall.NTarm.Services] +AddService = WinRedirect,,WinRedirect.Service + +[DefaultInstall.NTx86.Services] +AddService = WinRedirect,,WinRedirect.Service + +[WinRedirect.Service] +DisplayName = %ServiceName% +Description = %ServiceDesc% +ServiceBinary = %13%\winredirect.sys +ServiceType = 1 ; SERVICE_KERNEL_DRIVER +StartType = 3 ; SERVICE_DEMAND_START +ErrorControl = 1 ; SERVICE_ERROR_NORMAL + +[WinRedirect.Files] +winredirect.sys + +[SourceDisksNames] +1 = %DiskName% + +[SourceDisksFiles] +winredirect.sys = 1 + +[Strings] +ProviderString = "sing-tun" +ServiceName = "WinRedirect" +ServiceDesc = "WFP TCP Connection Redirect Driver" +DiskName = "WinRedirect Installation Disk" diff --git a/internal/winredirect/driver/winredirect.sln b/internal/winredirect/driver/winredirect.sln new file mode 100644 index 00000000..cb39766a --- /dev/null +++ b/internal/winredirect/driver/winredirect.sln @@ -0,0 +1,40 @@ + +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio Version 17 +VisualStudioVersion = 17.0.0.0 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "winredirect", "winredirect.vcxproj", "{A1B2C3D4-1234-5678-9ABC-DEF012345678}" +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|ARM = Debug|ARM + Debug|ARM64 = Debug|ARM64 + Debug|x64 = Debug|x64 + Debug|Win32 = Debug|Win32 + Release|ARM = Release|ARM + Release|ARM64 = Release|ARM64 + Release|x64 = Release|x64 + Release|Win32 = Release|Win32 + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {A1B2C3D4-1234-5678-9ABC-DEF012345678}.Debug|ARM.ActiveCfg = Debug|ARM + {A1B2C3D4-1234-5678-9ABC-DEF012345678}.Debug|ARM.Build.0 = Debug|ARM + {A1B2C3D4-1234-5678-9ABC-DEF012345678}.Debug|ARM64.ActiveCfg = Debug|ARM64 + {A1B2C3D4-1234-5678-9ABC-DEF012345678}.Debug|ARM64.Build.0 = Debug|ARM64 + {A1B2C3D4-1234-5678-9ABC-DEF012345678}.Debug|x64.ActiveCfg = Debug|x64 + {A1B2C3D4-1234-5678-9ABC-DEF012345678}.Debug|x64.Build.0 = Debug|x64 + {A1B2C3D4-1234-5678-9ABC-DEF012345678}.Debug|Win32.ActiveCfg = Debug|Win32 + {A1B2C3D4-1234-5678-9ABC-DEF012345678}.Debug|Win32.Build.0 = Debug|Win32 + {A1B2C3D4-1234-5678-9ABC-DEF012345678}.Release|ARM.ActiveCfg = Release|ARM + {A1B2C3D4-1234-5678-9ABC-DEF012345678}.Release|ARM.Build.0 = Release|ARM + {A1B2C3D4-1234-5678-9ABC-DEF012345678}.Release|ARM64.ActiveCfg = Release|ARM64 + {A1B2C3D4-1234-5678-9ABC-DEF012345678}.Release|ARM64.Build.0 = Release|ARM64 + {A1B2C3D4-1234-5678-9ABC-DEF012345678}.Release|x64.ActiveCfg = Release|x64 + {A1B2C3D4-1234-5678-9ABC-DEF012345678}.Release|x64.Build.0 = Release|x64 + {A1B2C3D4-1234-5678-9ABC-DEF012345678}.Release|Win32.ActiveCfg = Release|Win32 + {A1B2C3D4-1234-5678-9ABC-DEF012345678}.Release|Win32.Build.0 = Release|Win32 + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection +EndGlobal diff --git a/internal/winredirect/driver/winredirect.vcxproj b/internal/winredirect/driver/winredirect.vcxproj new file mode 100644 index 00000000..8d3afe3b --- /dev/null +++ b/internal/winredirect/driver/winredirect.vcxproj @@ -0,0 +1,111 @@ + + + + + Debug + Win32 + + + Release + Win32 + + + Debug + x64 + + + Release + x64 + + + Debug + ARM + + + Release + ARM + + + Debug + ARM64 + + + Release + ARM64 + + + + {A1B2C3D4-1234-5678-9ABC-DEF012345678} + {1bc93793-694f-48fe-9372-81e2b05556fd} + winredirect + 12.0 + winredirect + + + + + true + + + false + + + WindowsKernelModeDriver10.0 + Driver + KMDF + Desktop + Windows10 + <_NT_TARGET_VERSION>0xA000002 + false + + + + + + + + + + $(SolutionDir)build\$(Platform)\$(Configuration)\ + $(SolutionDir)intermediate\$(Platform)\$(Configuration)\ + + + + + %(AdditionalIncludeDirectories) + NDIS_WDM;NDIS630;NDIS_SUPPORT_NDIS6;%(PreprocessorDefinitions) + false + false + + + Fwpkclnt.lib;$(DDK_LIB_PATH)netio.lib;$(DDK_LIB_PATH)wdmsec.lib;$(KernelBufferOverflowLib);$(DDK_LIB_PATH)ntoskrnl.lib;$(DDK_LIB_PATH)hal.lib;$(DDK_LIB_PATH)wmilib.lib;$(KMDF_LIB_PATH)$(KMDF_VER_PATH)\WdfLdr.lib;$(KMDF_LIB_PATH)$(KMDF_VER_PATH)\WdfDriverEntry.lib;%(AdditionalDependencies) + false + + + true + + + false + + + sha256 + + + + + + + + + + + + true + + + + + + + + diff --git a/internal/winredirect/embed_windows_386.go b/internal/winredirect/embed_windows_386.go new file mode 100644 index 00000000..e941d3b5 --- /dev/null +++ b/internal/winredirect/embed_windows_386.go @@ -0,0 +1,6 @@ +package winredirect + +import _ "embed" + +//go:embed x86/winredirect.sys +var driverContent []byte diff --git a/internal/winredirect/embed_windows_amd64.go b/internal/winredirect/embed_windows_amd64.go new file mode 100644 index 00000000..2013e198 --- /dev/null +++ b/internal/winredirect/embed_windows_amd64.go @@ -0,0 +1,6 @@ +package winredirect + +import _ "embed" + +//go:embed amd64/winredirect.sys +var driverContent []byte diff --git a/internal/winredirect/embed_windows_arm.go b/internal/winredirect/embed_windows_arm.go new file mode 100644 index 00000000..46a3ba6f --- /dev/null +++ b/internal/winredirect/embed_windows_arm.go @@ -0,0 +1,6 @@ +package winredirect + +import _ "embed" + +//go:embed arm/winredirect.sys +var driverContent []byte diff --git a/internal/winredirect/embed_windows_arm64.go b/internal/winredirect/embed_windows_arm64.go new file mode 100644 index 00000000..4a108639 --- /dev/null +++ b/internal/winredirect/embed_windows_arm64.go @@ -0,0 +1,6 @@ +package winredirect + +import _ "embed" + +//go:embed arm64/winredirect.sys +var driverContent []byte diff --git a/internal/winredirect/ioctl_windows.go b/internal/winredirect/ioctl_windows.go new file mode 100644 index 00000000..cfe1ee36 --- /dev/null +++ b/internal/winredirect/ioctl_windows.go @@ -0,0 +1,65 @@ +package winredirect + +import ( + "unsafe" +) + +func (m *Manager) SetConfig(cfg *Config) error { + _, err := m.ioctl( + ioctlSetConfig, + unsafe.Pointer(cfg), + uint32(unsafe.Sizeof(*cfg)), + nil, 0, + ) + return err +} + +func (m *Manager) StartRedirect() error { + _, err := m.ioctl(ioctlStart, nil, 0, nil, 0) + return err +} + +func (m *Manager) StopRedirect() error { + _, err := m.ioctl(ioctlStop, nil, 0, nil, 0) + return err +} + +// GetPendingConn blocks until a connection needs a verdict. +// Multiple goroutines may call this concurrently (inverted IOCTL pattern). +func (m *Manager) GetPendingConn() (*PendingConn, error) { + var conn PendingConn + _, err := m.ioctl( + ioctlGetPending, + nil, 0, + unsafe.Pointer(&conn), + uint32(unsafe.Sizeof(conn)), + ) + if err != nil { + return nil, err + } + return &conn, nil +} + +func (m *Manager) SetVerdict(v *Verdict) error { + _, err := m.ioctl( + ioctlSetVerdict, + unsafe.Pointer(v), + uint32(unsafe.Sizeof(*v)), + nil, 0, + ) + return err +} + +func (m *Manager) GetFatalInfo() (*FatalInfo, error) { + var info FatalInfo + _, err := m.ioctl( + ioctlGetFatalInfo, + nil, 0, + unsafe.Pointer(&info), + uint32(unsafe.Sizeof(info)), + ) + if err != nil { + return nil, err + } + return &info, nil +} diff --git a/internal/winredirect/manager_windows.go b/internal/winredirect/manager_windows.go new file mode 100644 index 00000000..83a9dfc6 --- /dev/null +++ b/internal/winredirect/manager_windows.go @@ -0,0 +1,154 @@ +package winredirect + +import ( + "fmt" + "os" + "path/filepath" + "unsafe" + + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/svc/mgr" +) + +const ( + serviceName = "WinRedirect" + devicePath = `\\.\WinRedirect` +) + +type Manager struct { + driverPath string + device windows.Handle +} + +func NewManager() (*Manager, error) { + tmpDir, err := os.MkdirTemp("", "winredirect-*") + if err != nil { + return nil, fmt.Errorf("create temp dir: %w", err) + } + driverPath := filepath.Join(tmpDir, "winredirect.sys") + if err = os.WriteFile(driverPath, driverContent, 0o644); err != nil { + os.RemoveAll(tmpDir) + return nil, fmt.Errorf("write driver: %w", err) + } + return &Manager{ + driverPath: driverPath, + device: windows.InvalidHandle, + }, nil +} + +func (m *Manager) Install() error { + scm, err := mgr.Connect() + if err != nil { + return fmt.Errorf("connect SCM: %w", err) + } + defer scm.Disconnect() + + // Remove stale service if it exists + if existing, err := scm.OpenService(serviceName); err == nil { + existing.Control(windows.SERVICE_CONTROL_STOP) + existing.Delete() + existing.Close() + } + + svc, err := scm.CreateService(serviceName, m.driverPath, mgr.Config{ + ServiceType: windows.SERVICE_KERNEL_DRIVER, + StartType: mgr.StartManual, + }) + if err != nil { + return fmt.Errorf("create service: %w", err) + } + svc.Close() + return nil +} + +func (m *Manager) Start() error { + scm, err := mgr.Connect() + if err != nil { + return fmt.Errorf("connect SCM: %w", err) + } + defer scm.Disconnect() + + svc, err := scm.OpenService(serviceName) + if err != nil { + return fmt.Errorf("open service: %w", err) + } + defer svc.Close() + + if err = svc.Start(); err != nil { + return fmt.Errorf("start service: %w", err) + } + return nil +} + +func (m *Manager) OpenDevice() error { + path, err := windows.UTF16PtrFromString(devicePath) + if err != nil { + return err + } + handle, err := windows.CreateFile( + path, + windows.GENERIC_READ|windows.GENERIC_WRITE, + 0, + nil, + windows.OPEN_EXISTING, + windows.FILE_FLAG_OVERLAPPED, + 0, + ) + if err != nil { + return fmt.Errorf("open device: %w", err) + } + m.device = handle + return nil +} + +func (m *Manager) ioctl(code uint32, inBuf unsafe.Pointer, inSize uint32, outBuf unsafe.Pointer, outSize uint32) (uint32, error) { + var bytesReturned uint32 + overlapped := &windows.Overlapped{} + overlapped.HEvent, _ = windows.CreateEvent(nil, 1, 0, nil) + defer windows.CloseHandle(overlapped.HEvent) + + err := windows.DeviceIoControl( + m.device, + code, + (*byte)(inBuf), + inSize, + (*byte)(outBuf), + outSize, + &bytesReturned, + overlapped, + ) + if err == windows.ERROR_IO_PENDING { + _, err = windows.WaitForSingleObject(overlapped.HEvent, windows.INFINITE) + if err != nil { + return 0, err + } + err = windows.GetOverlappedResult(m.device, overlapped, &bytesReturned, false) + } + if err != nil { + return 0, err + } + return bytesReturned, nil +} + +func (m *Manager) Close() error { + if m.device != windows.InvalidHandle { + m.StopRedirect() + windows.CloseHandle(m.device) + m.device = windows.InvalidHandle + } + + scm, err := mgr.Connect() + if err == nil { + if svc, err := scm.OpenService(serviceName); err == nil { + svc.Control(windows.SERVICE_CONTROL_STOP) + svc.Delete() + svc.Close() + } + scm.Disconnect() + } + + if m.driverPath != "" { + os.RemoveAll(filepath.Dir(m.driverPath)) + } + return nil +} diff --git a/internal/winredirect/types_windows.go b/internal/winredirect/types_windows.go new file mode 100644 index 00000000..0dda8c0d --- /dev/null +++ b/internal/winredirect/types_windows.go @@ -0,0 +1,58 @@ +package winredirect + +// IOCTL codes matching the kernel driver definitions. +// CTL_CODE(FILE_DEVICE_NETWORK=0x12, function, METHOD_BUFFERED=0, FILE_ANY_ACCESS=0) +const ( + ioctlSetConfig = (0x00120000 | (0x800 << 2)) // IOCTL_WINREDIRECT_SET_CONFIG + ioctlStart = (0x00120000 | (0x801 << 2)) // IOCTL_WINREDIRECT_START + ioctlStop = (0x00120000 | (0x802 << 2)) // IOCTL_WINREDIRECT_STOP + ioctlGetPending = (0x00120000 | (0x803 << 2)) // IOCTL_WINREDIRECT_GET_PENDING + ioctlSetVerdict = (0x00120000 | (0x804 << 2)) // IOCTL_WINREDIRECT_SET_VERDICT + ioctlGetFatalInfo = (0x00120000 | (0x805 << 2)) // IOCTL_WINREDIRECT_GET_FATAL_INFO +) + +const ( + VerdictRedirect = 0 + // VerdictPermit allows the original TUN-bound connect to continue + // without local redirection. + VerdictPermit = 1 +) + +// Config is sent to the driver via IOCTL_SET_CONFIG. +// Must match WINREDIRECT_CONFIG in the driver. +type Config struct { + RedirectPort uint16 + _ [2]byte // padding + ProxyPID uint32 + TunGUID [16]byte +} + +// PendingConn is received from the driver via IOCTL_GET_PENDING. +// Must match WINREDIRECT_PENDING_CONN in the driver. +type PendingConn struct { + ConnID uint64 + AddressFamily uint8 + _ [3]byte // padding + SrcAddr [16]byte + SrcPort uint16 + _ [2]byte // padding + DstAddr [16]byte + DstPort uint16 + _ [2]byte // padding + ProcessID uint32 +} + +// Verdict is sent to the driver via IOCTL_SET_VERDICT. +// Must match WINREDIRECT_VERDICT in the driver. +type Verdict struct { + ConnID uint64 + Verdict uint32 + _ [4]byte // padding for alignment +} + +// FatalInfo is received from the driver via IOCTL_GET_FATAL_INFO. +// Must match WINREDIRECT_FATAL_INFO in the driver. +type FatalInfo struct { + Status uint32 + Message [128]byte +} diff --git a/internal/winredirect/x86/winredirect.sys b/internal/winredirect/x86/winredirect.sys new file mode 100644 index 00000000..a866793a Binary files /dev/null and b/internal/winredirect/x86/winredirect.sys differ diff --git a/internal/winsys/constants.go b/internal/winsys/constants.go index 1173ea9a..e71b4f74 100644 --- a/internal/winsys/constants.go +++ b/internal/winsys/constants.go @@ -132,9 +132,30 @@ var FWPM_CONDITION_ALE_APP_ID = windows.GUID{ } const ( + IPPROTO_TCP uint32 = 6 IPPROTO_UDP uint32 = 17 ) +// https://learn.microsoft.com/en-us/windows-hardware/drivers/network/ale-connect-redirect-layers +var FWPM_LAYER_ALE_CONNECT_REDIRECT_V4 = windows.GUID{ + Data1: 0xc4f7e4c3, + Data2: 0x3455, + Data3: 0x4c3a, + Data4: [8]byte{0xa2, 0x17, 0x31, 0x7c, 0x8f, 0xc8, 0xf0, 0xd1}, +} + +var FWPM_LAYER_ALE_CONNECT_REDIRECT_V6 = windows.GUID{ + Data1: 0x587e54a7, + Data2: 0x8440, + Data3: 0x4b2a, + Data4: [8]byte{0xa3, 0x53, 0x2e, 0x45, 0xd8, 0x80, 0x25, 0x4f}, +} + +const ( + FWP_ACTION_FLAG_CALLOUT uint32 = 0x00004000 + FWP_ACTION_CALLOUT_TERMINATING uint32 = (0x00000003 | FWP_ACTION_FLAG_CALLOUT | FWP_ACTION_FLAG_TERMINATING) +) + const ( FWP_ACTION_FLAG_TERMINATING uint32 = 0x00001000 FWP_ACTION_BLOCK uint32 = (0x00000001 | FWP_ACTION_FLAG_TERMINATING) diff --git a/nfqueue_linux.go b/nfqueue_linux.go index baaefb54..4b8a6eba 100644 --- a/nfqueue_linux.go +++ b/nfqueue_linux.go @@ -212,7 +212,7 @@ func (h *nfqueueHandler) handlePacket(attr nfqueue.Attribute) int { return 0 } - _, pErr := h.handler.PrepareConnection(N.NetworkTCP, srcAddr, dstAddr, nil, 0) + _, pErr := h.handler.PrepareConnection(h.ctx, N.NetworkTCP, srcAddr, dstAddr, nil, 0) // Use NfRepeat for bypass/reset so the packet re-enters the chain // from the beginning, allowing mark-checking rules to save the mark diff --git a/redirect.go b/redirect.go index dcf3e720..5a04528a 100644 --- a/redirect.go +++ b/redirect.go @@ -25,8 +25,10 @@ type AutoRedirect interface { type AutoRedirectOptions struct { TunOptions *Options Context context.Context + ConnContext func(ctx context.Context) context.Context Handler Handler Logger logger.Logger + ErrorHandler func(error) NetworkMonitor NetworkUpdateMonitor InterfaceFinder control.InterfaceFinder TableName string diff --git a/redirect_metadata.go b/redirect_metadata.go new file mode 100644 index 00000000..3a40e321 --- /dev/null +++ b/redirect_metadata.go @@ -0,0 +1,24 @@ +package tun + +import "context" + +// AutoRedirectMetadata carries process info obtained cheaply at redirect time. +// On Windows, WFP classify provides PID for free; Go resolves path from PID. +// This replaces the expensive process finder (netlink diag / sysctl / GetExtendedTcpTable) +// used in sing-box when process-based routing rules are configured. +type AutoRedirectMetadata struct { + ProcessID uint32 + ProcessPath string + UserId int32 // -1 if unknown +} + +type autoRedirectMetadataKey struct{} + +func ContextWithAutoRedirectMetadata(ctx context.Context, metadata *AutoRedirectMetadata) context.Context { + return context.WithValue(ctx, autoRedirectMetadataKey{}, metadata) +} + +func AutoRedirectMetadataFromContext(ctx context.Context) *AutoRedirectMetadata { + metadata, _ := ctx.Value(autoRedirectMetadataKey{}).(*AutoRedirectMetadata) + return metadata +} diff --git a/redirect_nftables_rules.go b/redirect_nftables_rules.go index 9f950a38..7e47be08 100644 --- a/redirect_nftables_rules.go +++ b/redirect_nftables_rules.go @@ -3,6 +3,7 @@ package tun import ( + "net" "net/netip" _ "unsafe" @@ -11,9 +12,8 @@ import ( "github.com/sagernet/nftables/expr" "github.com/sagernet/nftables/userdata" "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/ranges" - E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/ranges" "golang.org/x/exp/slices" "golang.org/x/sys/unix" @@ -375,6 +375,149 @@ func (r *autoRedirect) nftablesCreateExcludeRules(nft *nftables.Conn, table *nft }) } } + if len(r.tunOptions.IncludeMACAddress) > 0 { + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyIIFTYPE, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: binaryutil.NativeEndian.PutUint16(unix.ARPHRD_ETHER), + }, + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictReturn, + }, + }, + }) + if len(r.tunOptions.IncludeMACAddress) > 1 { + includeMACSet := &nftables.Set{ + Table: table, + Anonymous: true, + Constant: true, + KeyType: nftables.TypeEtherAddr, + } + err := nft.AddSet(includeMACSet, common.Map(r.tunOptions.IncludeMACAddress, func(it net.HardwareAddr) nftables.SetElement { + return nftables.SetElement{ + Key: []byte(it), + } + })) + if err != nil { + return err + } + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Payload{ + OperationType: expr.PayloadLoad, + DestRegister: 1, + Base: expr.PayloadBaseLLHeader, + Offset: 6, + Len: 6, + }, + &expr.Lookup{ + SourceRegister: 1, + SetID: includeMACSet.ID, + SetName: includeMACSet.Name, + Invert: true, + }, + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictReturn, + }, + }, + }) + } else { + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Payload{ + OperationType: expr.PayloadLoad, + DestRegister: 1, + Base: expr.PayloadBaseLLHeader, + Offset: 6, + Len: 6, + }, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: []byte(r.tunOptions.IncludeMACAddress[0]), + }, + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictReturn, + }, + }, + }) + } + } + if len(r.tunOptions.ExcludeMACAddress) > 0 { + if len(r.tunOptions.ExcludeMACAddress) > 1 { + excludeMACSet := &nftables.Set{ + Table: table, + Anonymous: true, + Constant: true, + KeyType: nftables.TypeEtherAddr, + } + err := nft.AddSet(excludeMACSet, common.Map(r.tunOptions.ExcludeMACAddress, func(it net.HardwareAddr) nftables.SetElement { + return nftables.SetElement{ + Key: []byte(it), + } + })) + if err != nil { + return err + } + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Payload{ + OperationType: expr.PayloadLoad, + DestRegister: 1, + Base: expr.PayloadBaseLLHeader, + Offset: 6, + Len: 6, + }, + &expr.Lookup{ + SourceRegister: 1, + SetID: excludeMACSet.ID, + SetName: excludeMACSet.Name, + }, + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictReturn, + }, + }, + }) + } else { + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Payload{ + OperationType: expr.PayloadLoad, + DestRegister: 1, + Base: expr.PayloadBaseLLHeader, + Offset: 6, + Len: 6, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte(r.tunOptions.ExcludeMACAddress[0]), + }, + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictReturn, + }, + }, + }) + } + } } else { if len(r.tunOptions.IncludeUID) > 0 { if len(r.tunOptions.IncludeUID) > 1 || r.tunOptions.IncludeUID[0].Start != r.tunOptions.IncludeUID[0].End { diff --git a/redirect_route_linux.go b/redirect_route_linux.go index db79cac6..7e0868c6 100644 --- a/redirect_route_linux.go +++ b/redirect_route_linux.go @@ -8,9 +8,9 @@ import ( "net/netip" "github.com/sagernet/netlink" - E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/control" + E "github.com/sagernet/sing/common/exceptions" "golang.org/x/sys/unix" ) diff --git a/redirect_server.go b/redirect_server.go index 7590b35d..822322fb 100644 --- a/redirect_server.go +++ b/redirect_server.go @@ -26,6 +26,12 @@ type redirectServer struct { inShutdown atomic.Bool } +func (s *redirectServer) logError(args ...any) { + if s.logger != nil { + s.logger.Error(args...) + } +} + func newRedirectServer(ctx context.Context, handler N.TCPConnectionHandlerEx, logger logger.Logger, listenAddr netip.Addr) *redirectServer { return &redirectServer{ ctx: ctx, @@ -60,14 +66,14 @@ func (s *redirectServer) loopIn() { var netError net.Error //nolint:staticcheck if errors.As(err, &netError) && netError.Temporary() { - s.logger.Error(err) + s.logError(err) continue } if s.inShutdown.Load() && E.IsClosed(err) { return } s.listener.Close() - s.logger.Error("serve error: ", err) + s.logError("serve error: ", err) continue } source := M.SocksaddrFromNet(conn.RemoteAddr()).Unwrap() @@ -75,7 +81,7 @@ func (s *redirectServer) loopIn() { if err != nil { _ = conn.SetLinger(0) _ = conn.Close() - s.logger.Error("process redirect connection from ", source, ": invalid connection: ", err) + s.logError("process redirect connection from ", source, ": invalid connection: ", err) continue } go s.handler.NewConnectionEx(s.ctx, conn, source, M.SocksaddrFromNetIP(destination).Unwrap(), nil) diff --git a/redirect_server_windows.go b/redirect_server_windows.go new file mode 100644 index 00000000..22cd6fba --- /dev/null +++ b/redirect_server_windows.go @@ -0,0 +1,217 @@ +package tun + +import ( + "context" + "errors" + "net" + "net/netip" + "sync" + "sync/atomic" + "time" + + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +type redirectServer struct { + ctx context.Context + handler N.TCPConnectionHandlerEx + logger logger.Logger + onFatal func(error) + listenAddr netip.Addr + listener *net.TCPListener + connTable *connMetadataTable + inShutdown atomic.Bool +} + +func (s *redirectServer) logError(args ...any) { + if s.logger != nil { + s.logger.Error(args...) + } +} + +func newRedirectServerWindows(ctx context.Context, handler N.TCPConnectionHandlerEx, logger logger.Logger, listenAddr netip.Addr, onFatal func(error)) *redirectServer { + return &redirectServer{ + ctx: ctx, + handler: handler, + logger: logger, + onFatal: onFatal, + listenAddr: listenAddr, + connTable: newConnMetadataTable(), + } +} + +func (s *redirectServer) Start() error { + var listenConfig net.ListenConfig + listenConfig.KeepAlive = 10 * time.Minute + listener, err := listenConfig.Listen(s.ctx, M.NetworkFromNetAddr("tcp", s.listenAddr), M.SocksaddrFrom(s.listenAddr, 0).String()) + if err != nil { + return err + } + s.listener = listener.(*net.TCPListener) + go s.loopIn() + return nil +} + +func (s *redirectServer) Close() error { + s.inShutdown.Store(true) + if s.connTable != nil { + s.connTable.Close() + } + if s.listener != nil { + return s.listener.Close() + } + return nil +} + +func (s *redirectServer) loopIn() { + for { + conn, err := s.listener.AcceptTCP() + if err != nil { + var netError net.Error + //nolint:staticcheck + if errors.As(err, &netError) && netError.Temporary() { + s.logError(err) + continue + } + if s.inShutdown.Load() && E.IsClosed(err) { + return + } + s.listener.Close() + if s.onFatal != nil { + s.onFatal(E.Cause(err, "accept redirect connection")) + } else { + s.logError("serve error: ", err) + } + return + } + source := M.SocksaddrFromNet(conn.RemoteAddr()).Unwrap() + entry, ok := s.connTable.Lookup(source) + if !ok { + _ = conn.SetLinger(0) + _ = conn.Close() + s.logError("process redirect connection from ", source, ": no metadata") + continue + } + destination := entry.Destination + if entry.IsDNS { + destination = entry.DNSServer + } + ctx := entry.Context + if ctx == nil { + ctx = s.ctx + } + go s.handler.NewConnectionEx(ctx, conn, source, destination, nil) + } +} + +type connMetadataTable struct { + mu sync.Mutex + entries map[connKey]*connEntry + done chan struct{} +} + +type connKey struct { + Addr netip.Addr + Port uint16 +} + +type connEntry struct { + Destination M.Socksaddr + Context context.Context + IsDNS bool + DNSServer M.Socksaddr + CreatedAt time.Time +} + +func newConnMetadataTable() *connMetadataTable { + t := &connMetadataTable{ + entries: make(map[connKey]*connEntry), + done: make(chan struct{}), + } + go t.cleanupLoop() + return t +} + +func (t *connMetadataTable) Close() { + select { + case <-t.done: + default: + close(t.done) + } +} + +func (t *connMetadataTable) Store(src M.Socksaddr, dst M.Socksaddr, ctx context.Context) { + key := connKey{Addr: src.Addr, Port: src.Port} + t.mu.Lock() + defer t.mu.Unlock() + t.entries[key] = &connEntry{ + Destination: dst, + Context: ctx, + CreatedAt: time.Now(), + } +} + +func (t *connMetadataTable) StoreDNS(src M.Socksaddr, originalDst M.Socksaddr, dnsServer M.Socksaddr, ctx context.Context) { + key := connKey{Addr: src.Addr, Port: src.Port} + t.mu.Lock() + defer t.mu.Unlock() + t.entries[key] = &connEntry{ + Destination: originalDst, + Context: ctx, + IsDNS: true, + DNSServer: dnsServer, + CreatedAt: time.Now(), + } +} + +func (t *connMetadataTable) Lookup(src M.Socksaddr) (*connEntry, bool) { + key := connKey{Addr: src.Addr, Port: src.Port} + t.mu.Lock() + defer t.mu.Unlock() + entry, ok := t.entries[key] + if ok { + delete(t.entries, key) + return entry, true + } + + // ALE_CONNECT_REDIRECT may report an unspecified local source address + // (0.0.0.0/::) before connect completes, while the accepted redirected + // connection later appears as loopback with the same source port. + if src.Addr.IsLoopback() { + var fallbackAddr netip.Addr + if src.Addr.Is4() { + fallbackAddr = netip.IPv4Unspecified() + } else { + fallbackAddr = netip.IPv6Unspecified() + } + fallbackKey := connKey{Addr: fallbackAddr, Port: src.Port} + entry, ok = t.entries[fallbackKey] + if ok { + delete(t.entries, fallbackKey) + } + } + return entry, ok +} + +func (t *connMetadataTable) cleanupLoop() { + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + for { + select { + case <-t.done: + return + case <-ticker.C: + t.mu.Lock() + now := time.Now() + for key, entry := range t.entries { + if now.Sub(entry.CreatedAt) > 30*time.Second { + delete(t.entries, key) + } + } + t.mu.Unlock() + } + } +} diff --git a/redirect_stub.go b/redirect_stub.go index 040ef124..8bd94ca2 100644 --- a/redirect_stub.go +++ b/redirect_stub.go @@ -1,4 +1,4 @@ -//go:build !linux +//go:build !(linux || windows) package tun diff --git a/redirect_windows.go b/redirect_windows.go new file mode 100644 index 00000000..849449a8 --- /dev/null +++ b/redirect_windows.go @@ -0,0 +1,437 @@ +package tun + +import ( + "context" + "errors" + "net" + "net/netip" + "os" + "slices" + "strings" + "sync" + "sync/atomic" + "unsafe" + + "github.com/sagernet/sing-tun/internal/winipcfg" + "github.com/sagernet/sing-tun/internal/winredirect" + "github.com/sagernet/sing-tun/internal/winsys" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/control" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/x/list" + + "golang.org/x/sys/windows" +) + +type autoRedirect struct { + tunOptions *Options + ctx context.Context + connContext func(context.Context) context.Context + handler Handler + logger logger.Logger + errorHandler func(error) + networkMonitor NetworkUpdateMonitor + networkListener *list.Element[NetworkUpdateCallback] + interfaceFinder control.InterfaceFinder + driverManager *winredirect.Manager + redirectServer *redirectServer + + selfPID uint32 + enableIPv4 bool + enableIPv6 bool + + localAddressMu sync.RWMutex + localAddresses []netip.Prefix + + closing atomic.Bool + closeOnce sync.Once + closeErr error + fatalOnce sync.Once +} + +func NewAutoRedirect(options AutoRedirectOptions) (AutoRedirect, error) { + r := &autoRedirect{ + tunOptions: options.TunOptions, + ctx: options.Context, + connContext: options.ConnContext, + handler: options.Handler, + logger: options.Logger, + errorHandler: options.ErrorHandler, + networkMonitor: options.NetworkMonitor, + interfaceFinder: options.InterfaceFinder, + } + return r, nil +} + +func (r *autoRedirect) Start() error { + r.selfPID = uint32(os.Getpid()) + r.enableIPv4 = len(r.tunOptions.Inet4Address) > 0 + r.enableIPv6 = len(r.tunOptions.Inet6Address) > 0 + if !r.enableIPv4 && !r.enableIPv6 { + return E.New("no address configured") + } + err := r.startRedirect() + if err != nil { + common.Close( + common.PtrOrNil(r.redirectServer), + common.PtrOrNil(r.driverManager), + ) + r.redirectServer = nil + r.driverManager = nil + return err + } + r.updateLocalAddresses() + if r.networkMonitor != nil { + r.networkListener = r.networkMonitor.RegisterCallback(func() { + r.updateLocalAddresses() + }) + } + go r.dispatchLoop() + return nil +} + +func (r *autoRedirect) startRedirect() error { + manager, err := winredirect.NewManager() + if err != nil { + return E.Cause(err, "create driver manager") + } + r.driverManager = manager + + err = manager.Install() + if err != nil { + return E.Cause(err, "install driver") + } + + err = manager.Start() + if err != nil { + return E.Cause(err, "start driver") + } + + err = manager.OpenDevice() + if err != nil { + return E.Cause(err, "open driver device") + } + + var listenAddr netip.Addr + if r.enableIPv6 { + listenAddr = netip.IPv6Unspecified() + } else { + listenAddr = netip.IPv4Unspecified() + } + server := newRedirectServerWindows(r.ctx, r.handler, r.logger, listenAddr, r.handleFatalError) + r.redirectServer = server + err = server.Start() + if err != nil { + return E.Cause(err, "start redirect server") + } + + tunGUID, err := r.resolveTunInterfaceGUID() + if err != nil { + return E.Cause(err, "resolve tun interface") + } + + redirectPort := M.AddrPortFromNet(server.listener.Addr()).Port() + err = manager.SetConfig(&winredirect.Config{ + RedirectPort: redirectPort, + ProxyPID: r.selfPID, + TunGUID: tunGUID, + }) + if err != nil { + return E.Cause(err, "set driver config") + } + + err = manager.StartRedirect() + if err != nil { + return E.Cause(err, "start redirect") + } + + return nil +} + +func (r *autoRedirect) Close() error { + r.closing.Store(true) + r.closeOnce.Do(func() { + if r.networkMonitor != nil && r.networkListener != nil { + r.networkMonitor.UnregisterCallback(r.networkListener) + r.networkListener = nil + } + r.closeErr = common.Close( + common.PtrOrNil(r.redirectServer), + common.PtrOrNil(r.driverManager), + ) + }) + return r.closeErr +} + +func (r *autoRedirect) UpdateRouteAddressSet() { +} + +func (r *autoRedirect) dispatchLoop() { + for { + conn, err := r.driverManager.GetPendingConn() + if err != nil { + if !r.closing.Load() { + r.handleFatalError(E.Cause(r.enrichDriverError(err), "get pending connection")) + } + return + } + go r.handlePendingConn(conn) + } +} + +func (r *autoRedirect) handlePendingConn(conn *winredirect.PendingConn) { + verdict := r.evaluateConnection(conn) + err := r.driverManager.SetVerdict(&winredirect.Verdict{ + ConnID: conn.ConnID, + Verdict: verdict, + }) + if err != nil { + if isStaleVerdictError(err) { + return + } + if !r.closing.Load() { + r.handleFatalError(E.Cause(r.enrichDriverError(err), "set redirect verdict")) + } + } +} + +func isStaleVerdictError(err error) bool { + return errors.Is(err, windows.ERROR_NOT_FOUND) || errors.Is(err, windows.ERROR_FILE_NOT_FOUND) +} + +func (r *autoRedirect) enrichDriverError(err error) error { + fatalInfo, infoErr := r.driverManager.GetFatalInfo() + if infoErr != nil || fatalInfo.Status == 0 { + return err + } + message := strings.TrimRight(string(fatalInfo.Message[:]), "\x00") + if message == "" { + return err + } + return E.Cause(err, E.Cause(windows.Errno(fatalInfo.Status), message)) +} + +func (r *autoRedirect) handleFatalError(err error) { + if err == nil || r.closing.Load() { + return + } + r.fatalOnce.Do(func() { + if r.logger != nil { + r.logger.Error("windows auto-redirect fatal error: ", err) + } + _ = r.Close() + if r.errorHandler != nil { + r.errorHandler(err) + } + }) +} + +func (r *autoRedirect) evaluateConnection(conn *winredirect.PendingConn) uint32 { + dst := pendingConnDst(conn) + src := pendingConnSrc(conn) + + // Proxy process outbound connections must never be redirected back into itself. + if conn.ProcessID == r.selfPID { + return winredirect.VerdictPermit + } + + if dst.Addr.IsLoopback() { + return winredirect.VerdictPermit + } + + if !r.tunOptions.EXP_DisableDNSHijack && dst.Port == 53 { + if r.isLocalAddress(src.Addr) { + dnsServer := r.dnsServerForFamily(dst.Addr) + if dnsServer.IsValid() { + metadata := r.resolveMetadata(conn) + ctx := r.newConnContext(metadata) + _, err := r.handler.PrepareConnection(ctx, "tcp", src, dst, nil, 0) + if errors.Is(err, ErrDrop) { + return winredirect.VerdictPermit + } + r.redirectServer.connTable.StoreDNS(src, dst, M.SocksaddrFrom(dnsServer, 53), ctx) + return winredirect.VerdictRedirect + } + } + } + + if r.tunOptions.StrictRoute && r.isDisabledFamily(dst.Addr) { + return winredirect.VerdictPermit + } + + metadata := r.resolveMetadata(conn) + ctx := r.newConnContext(metadata) + + _, err := r.handler.PrepareConnection(ctx, N.NetworkTCP, src, dst, nil, 0) + if errors.Is(err, ErrDrop) { + return winredirect.VerdictPermit + } + if errors.Is(err, ErrReset) { + // Pending entries reaching userspace here have already been identified as + // TUN-bound by the driver. Permit means "do not locally redirect to the + // Windows redirect listener"; the original connect continues into the TUN, + // where reset semantics are enforced by the TUN stack. + return winredirect.VerdictPermit + } + if errors.Is(err, ErrBypass) && r.logger != nil { + r.logger.Debug("bypass not supported on Windows, redirecting: ", src, " -> ", dst) + } + + r.redirectServer.connTable.Store(src, dst, ctx) + + return winredirect.VerdictRedirect +} + +func (r *autoRedirect) newConnContext(metadata *AutoRedirectMetadata) context.Context { + ctx := r.ctx + if ctx == nil { + ctx = context.Background() + } + if r.connContext != nil { + ctx = r.connContext(ctx) + } + if metadata != nil { + ctx = ContextWithAutoRedirectMetadata(ctx, metadata) + } + return ctx +} + +func (r *autoRedirect) resolveMetadata(conn *winredirect.PendingConn) *AutoRedirectMetadata { + processPath, _ := queryFullProcessImageName(conn.ProcessID) + return &AutoRedirectMetadata{ + ProcessID: conn.ProcessID, + ProcessPath: processPath, + UserId: -1, + } +} + +func (r *autoRedirect) updateLocalAddresses() { + if r.interfaceFinder == nil { + return + } + r.interfaceFinder.Update() + newLocalAddresses := common.FlatMap(r.interfaceFinder.Interfaces(), func(it control.Interface) []netip.Prefix { + return common.Filter(it.Addresses, func(prefix netip.Prefix) bool { + return it.Name == "Loopback Pseudo-Interface 1" || prefix.Addr().IsGlobalUnicast() + }) + }) + r.localAddressMu.Lock() + defer r.localAddressMu.Unlock() + if slices.Equal(newLocalAddresses, r.localAddresses) { + return + } + r.localAddresses = newLocalAddresses + if r.logger != nil { + r.logger.Debug("updating local address set to [", strings.Join(common.Map(newLocalAddresses, func(it netip.Prefix) string { + return it.String() + }), ", ")+"]") + } +} + +func (r *autoRedirect) isLocalAddress(addr netip.Addr) bool { + r.localAddressMu.RLock() + defer r.localAddressMu.RUnlock() + for _, prefix := range r.localAddresses { + if prefix.Contains(addr) { + return true + } + } + return false +} + +func (r *autoRedirect) isDisabledFamily(addr netip.Addr) bool { + if addr.Is4() { + return !r.enableIPv4 + } + return !r.enableIPv6 +} + +func (r *autoRedirect) dnsServerForFamily(addr netip.Addr) netip.Addr { + isV4 := addr.Is4() + dnsServer := common.Find(r.tunOptions.DNSServers, func(it netip.Addr) bool { + return it.Is4() == isV4 + }) + if dnsServer.IsValid() { + return dnsServer + } + if isV4 { + if len(r.tunOptions.Inet4Address) > 0 && HasNextAddress(r.tunOptions.Inet4Address[0], 1) { + return r.tunOptions.Inet4Address[0].Addr().Next() + } + } else { + if len(r.tunOptions.Inet6Address) > 0 && HasNextAddress(r.tunOptions.Inet6Address[0], 1) { + return r.tunOptions.Inet6Address[0].Addr().Next() + } + } + return netip.Addr{} +} + +func (r *autoRedirect) resolveTunInterfaceGUID() ([16]byte, error) { + var index int + if r.interfaceFinder != nil { + err := r.interfaceFinder.Update() + if err != nil && r.logger != nil { + r.logger.Debug("update interface finder: ", err) + } + iface, err := r.interfaceFinder.ByName(r.tunOptions.Name) + if err != nil && r.logger != nil { + r.logger.Debug("interface finder lookup: ", err) + } + if err == nil { + index = iface.Index + } + } + if index == 0 { + iface, err := net.InterfaceByName(r.tunOptions.Name) + if err != nil { + return [16]byte{}, err + } + index = iface.Index + } + luid, err := winipcfg.LUIDFromIndex(uint32(index)) + if err != nil { + return [16]byte{}, err + } + guid, err := luid.GUID() + if err != nil { + return [16]byte{}, err + } + return guidBytes(guid), nil +} + +func guidBytes(guid *windows.GUID) [16]byte { + return *(*[16]byte)(unsafe.Pointer(guid)) +} + +func pendingConnSrc(conn *winredirect.PendingConn) M.Socksaddr { + return M.SocksaddrFrom(pendingAddr(conn.AddressFamily, conn.SrcAddr), conn.SrcPort) +} + +func pendingConnDst(conn *winredirect.PendingConn) M.Socksaddr { + return M.SocksaddrFrom(pendingAddr(conn.AddressFamily, conn.DstAddr), conn.DstPort) +} + +func pendingAddr(af uint8, raw [16]byte) netip.Addr { + if af == winsys.AF_INET { + return netip.AddrFrom4([4]byte(raw[:4])) + } + return netip.AddrFrom16(raw) +} + +func queryFullProcessImageName(pid uint32) (string, error) { + handle, err := windows.OpenProcess(windows.PROCESS_QUERY_LIMITED_INFORMATION, false, pid) + if err != nil { + return "", err + } + defer windows.CloseHandle(handle) + var buf [windows.MAX_PATH]uint16 + n := uint32(len(buf)) + err = windows.QueryFullProcessImageName(handle, 0, &buf[0], &n) + if err != nil { + return "", err + } + return windows.UTF16ToString(buf[:n]), nil +} diff --git a/stack_gvisor_icmp.go b/stack_gvisor_icmp.go index da5549b6..ed33c335 100644 --- a/stack_gvisor_icmp.go +++ b/stack_gvisor_icmp.go @@ -62,6 +62,7 @@ func (f *ICMPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pa if destinationAddr != f.inet4Address { action, err := f.mapping.Lookup(DirectRouteSession{Source: sourceAddr, Destination: destinationAddr}, func(timeout time.Duration) (DirectRouteDestination, error) { return f.handler.PrepareConnection( + f.ctx, N.NetworkICMP, M.SocksaddrFrom(sourceAddr, 0), M.SocksaddrFrom(destinationAddr, 0), @@ -123,6 +124,7 @@ func (f *ICMPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pa if destinationAddr != f.inet6Address { action, err := f.mapping.Lookup(DirectRouteSession{Source: sourceAddr, Destination: destinationAddr}, func(timeout time.Duration) (DirectRouteDestination, error) { return f.handler.PrepareConnection( + f.ctx, N.NetworkICMP, M.SocksaddrFrom(sourceAddr, 0), M.SocksaddrFrom(destinationAddr, 0), diff --git a/stack_gvisor_tcp.go b/stack_gvisor_tcp.go index 0c63ee11..82327509 100644 --- a/stack_gvisor_tcp.go +++ b/stack_gvisor_tcp.go @@ -79,7 +79,7 @@ func (f *TCPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pac func (f *TCPForwarder) Forward(r *tcp.ForwarderRequest) { source := M.SocksaddrFrom(AddrFromAddress(r.ID().RemoteAddress), r.ID().RemotePort) destination := M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort) - _, pErr := f.handler.PrepareConnection(N.NetworkTCP, source, destination, nil, 0) + _, pErr := f.handler.PrepareConnection(f.ctx, N.NetworkTCP, source, destination, nil, 0) if pErr != nil { r.Complete(!errors.Is(pErr, ErrDrop)) return diff --git a/stack_gvisor_udp.go b/stack_gvisor_udp.go index 2e8ff3e7..1a21c19d 100644 --- a/stack_gvisor_udp.go +++ b/stack_gvisor_udp.go @@ -58,7 +58,7 @@ func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pac func rangeIterate(r stack.Range, fn func(*buffer.View)) func (f *UDPForwarder) PreparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) { - _, pErr := f.handler.PrepareConnection(N.NetworkUDP, source, destination, nil, 0) + _, pErr := f.handler.PrepareConnection(f.ctx, N.NetworkUDP, source, destination, nil, 0) if pErr != nil { if !errors.Is(pErr, ErrDrop) { gWriteUnreachable(f.stack, userData.(*stack.PacketBuffer)) diff --git a/stack_system.go b/stack_system.go index e2cdd45e..84127470 100644 --- a/stack_system.go +++ b/stack_system.go @@ -399,7 +399,7 @@ func (s *System) processIPv4TCP(ipHdr header.IPv4, tcpHdr header.TCP) (bool, err } } if !loopback { - natPort, err := s.tcpNat.Lookup(source, destination, s.handler) + natPort, err := s.tcpNat.Lookup(s.ctx, source, destination, s.handler) if err != nil { if errors.Is(err, ErrDrop) { return false, nil @@ -494,7 +494,7 @@ func (s *System) processIPv6TCP(ipHdr header.IPv6, tcpHdr header.TCP) (bool, err } } if !loopback { - natPort, err := s.tcpNat.Lookup(source, destination, s.handler) + natPort, err := s.tcpNat.Lookup(s.ctx, source, destination, s.handler) if err != nil { if errors.Is(err, ErrDrop) { return false, nil @@ -589,7 +589,7 @@ func (s *System) processIPv6UDP(ipHdr header.IPv6, udpHdr header.UDP) error { } func (s *System) preparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) { - _, pErr := s.handler.PrepareConnection(N.NetworkUDP, source, destination, nil, 0) + _, pErr := s.handler.PrepareConnection(s.ctx, N.NetworkUDP, source, destination, nil, 0) if pErr != nil { if !errors.Is(pErr, ErrDrop) { if source.IsIPv4() { @@ -640,6 +640,7 @@ func (s *System) processIPv4ICMP(ipHdr header.IPv4, icmpHdr header.ICMPv4) (bool if destinationAddr != s.inet4Address { action, err := s.directNat.Lookup(DirectRouteSession{Source: sourceAddr, Destination: destinationAddr}, func(timeout time.Duration) (DirectRouteDestination, error) { return s.handler.PrepareConnection( + s.ctx, N.NetworkICMP, M.SocksaddrFrom(sourceAddr, 0), M.SocksaddrFrom(destinationAddr, 0), @@ -715,6 +716,7 @@ func (s *System) processIPv6ICMP(ipHdr header.IPv6, icmpHdr header.ICMPv6) (bool if destinationAddr != s.inet6Address { action, err := s.directNat.Lookup(DirectRouteSession{Source: sourceAddr, Destination: destinationAddr}, func(timeout time.Duration) (DirectRouteDestination, error) { return s.handler.PrepareConnection( + s.ctx, N.NetworkICMP, M.SocksaddrFrom(sourceAddr, 0), M.SocksaddrFrom(destinationAddr, 0), diff --git a/stack_system_nat.go b/stack_system_nat.go index 636b561d..0ab62beb 100644 --- a/stack_system_nat.go +++ b/stack_system_nat.go @@ -80,14 +80,14 @@ func (n *TCPNat) LookupBack(port uint16) *TCPSession { return session } -func (n *TCPNat) Lookup(source netip.AddrPort, destination netip.AddrPort, handler Handler) (uint16, error) { +func (n *TCPNat) Lookup(ctx context.Context, source netip.AddrPort, destination netip.AddrPort, handler Handler) (uint16, error) { n.addrAccess.RLock() port, loaded := n.addrMap[source] n.addrAccess.RUnlock() if loaded { return port, nil } - _, pErr := handler.PrepareConnection(N.NetworkTCP, M.SocksaddrFromNetIP(source), M.SocksaddrFromNetIP(destination), nil, 0) + _, pErr := handler.PrepareConnection(ctx, N.NetworkTCP, M.SocksaddrFromNetIP(source), M.SocksaddrFromNetIP(destination), nil, 0) if pErr != nil { return 0, pErr } diff --git a/tun.go b/tun.go index 35cd0956..608e46f1 100644 --- a/tun.go +++ b/tun.go @@ -1,6 +1,7 @@ package tun import ( + "context" "io" "net" "net/netip" @@ -20,6 +21,7 @@ import ( type Handler interface { PrepareConnection( + ctx context.Context, network string, source M.Socksaddr, destination M.Socksaddr, @@ -102,6 +104,8 @@ type Options struct { IncludeAndroidUser []int IncludePackage []string ExcludePackage []string + IncludeMACAddress []net.HardwareAddr + ExcludeMACAddress []net.HardwareAddr InterfaceFinder control.InterfaceFinder InterfaceMonitor DefaultInterfaceMonitor FileDescriptor int