// Copyright 2025 David Ireland, DI Management Services Pty Limited
// <www.di-mgt.com.au> <www.cryptosys.net>.
// Use of this source code is governed by the MIT License
// available at https://opensource.org/license/mit
// $Date: 2025-07-27 05:17Z

// Interface to myTestDll.dll
// This code demonstrates how to call functions from a Windows DLL using Go.
// It includes examples of calling functions with different parameter types
// and handling string conversions and negative integers.
package main

/* ``myTestDll.dll`` function signatures:
LONG __stdcall MyVoidFunc(void);
LONG __stdcall MyIntFunc(LONG n);
LONG __stdcall MyStringFunc(LPSTR szOutput, DWORD nOutChars, LPCSTR szInput, DWORD nOptions);
LONG __stdcall MyByteFunc(BYTE *lpOutput, DWORD nOutBytes, CONST BYTE *lpInput, DWORD nInBytes, DWORD nOptions);
LONG __stdcall MyUnicodeFunc(LPWSTR wsOutput, DWORD nOutChars, LPCWSTR wsInput, DWORD nOptions);
*/

import (
    "encoding/hex"
    "fmt"
    "syscall"
    "unicode/utf16"
    "unsafe"
)

// INTERNAL UTILS
// stringToCharPtr and stringToUTF16Ptr by Justen Walker
// https://medium.com/@justen.walker/breaking-all-the-rules-using-go-to-call-windows-api-2cbfd8c79724
// https://gist.github.com/justenwalker/d2fa7c80e6454bf7d5314ee2f28d1b00#file-string_convert_windows-go

// stringToCharPtr returns a pointer to the first byte of a freshly allocated,
// null-terminated copy of str, suitable for passing as an LPCSTR argument.
// The bytes of str are copied verbatim (no code-page conversion is performed),
// so a NUL byte inside str ends the C string early.
func stringToCharPtr(str string) *uint8 {
    buf := make([]byte, len(str)+1)
    copy(buf, str) // the final byte is left as 0 by make: the terminator
    return &buf[0]
}

// stringToUTF16Ptr returns a pointer to the first code unit of a freshly
// allocated, null-terminated UTF-16 encoding of str, suitable for passing as
// an LPCWSTR argument. Characters outside the Basic Multilingual Plane are
// written as surrogate pairs; invalid UTF-8 in str becomes U+FFFD.
func stringToUTF16Ptr(str string) *uint16 {
    units := utf16.Encode([]rune(str))
    units = append(units, 0) // null terminator
    return &units[0]
}

// utf16PtrToString converts a pointer to a null-terminated UTF-16 string
// (such as an LPWSTR filled in by a Windows DLL) into a Go string.
//
// It reads 16-bit code units starting at ptr up to (not including) the first
// zero code unit, then decodes them with utf16.Decode. Decoding the whole run
// at once (rather than converting each code unit to a rune) correctly joins
// surrogate pairs, so characters outside the Basic Multilingual Plane, e.g.
// emoji, survive; unpaired surrogates become U+FFFD.
//
// It returns "" if ptr is nil. The caller must guarantee that the memory at
// ptr is null-terminated; otherwise this reads past the end of the buffer.
func utf16PtrToString(ptr *uint16) string {
    if ptr == nil {
        return ""
    }
    var units []uint16
    // Pointer arithmetic is done in a single expression, as the unsafe.Pointer rules require.
    for p := ptr; *p != 0; p = (*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(p)) + unsafe.Sizeof(*p))) {
        units = append(units, *p)
    }
    return string(utf16.Decode(units))
}

// END INTERNAL UTILS

// main demonstrates calling each function exported by myTestDll.dll from Go
// through the syscall package: no parameters, integer in/out, ANSI string
// in/out, byte array in/out, and UTF-16 string in/out.
//
// Every exported function returns a LONG (32-bit signed), which proc.Call
// hands back as a uintptr. On 64-bit Windows the upper 32 bits of the return
// register are not guaranteed to be a sign extension of a 32-bit result, so a
// negative LONG such as -1 may arrive as 0x00000000FFFFFFFF and int(ret) would
// be positive. Results are therefore narrowed to int32 before their sign is
// tested (see longResult).
//
// The third result of proc.Call (the thread's last error) is ignored: these
// functions report their status through the LONG return value.
func main() {
    // Load the DLL once and look every procedure up from the same handle.
    dll := syscall.NewLazyDLL("myTestDll.dll")

    // 1. Calling a function with no parameters
    // The function signature is:
    //   LONG __stdcall MyVoidFunc(void);
    proc := dll.NewProc("MyVoidFunc")
    // Call is variadic, so a function with no parameters is called with no arguments.
    ret, _, _ := proc.Call()
    fmt.Printf("MyVoidFunc returns %d\n", ret) // 42
    // `ret` is a uintptr, an unsigned integer type, so arithmetic on it is
    // unsigned: negating it wraps around instead of giving -42.
    fmt.Println("-ret=", -ret, "!!!") // 18446744073709551574 on 64-bit, uh!
    // The function really returns a LONG (32-bit signed), so narrow to int32
    // first; the sign then comes from bit 31 regardless of register width.
    i := int(int32(-ret))
    fmt.Println("int(int32(-ret))=", i) // -42, as expected

    // 2. Calling a function with an int32 input parameter
    // The function signature is:
    //   LONG __stdcall MyIntFunc(LONG n);
    proc = dll.NewProc("MyIntFunc")
    // An untyped constant such as 0 converts to uintptr implicitly.
    ret, _, _ = proc.Call(0)
    fmt.Printf("MyIntFunc(0) returns %d\n", ret) // 0
    n := 888
    // A typed variable (here an int) must be converted explicitly to uintptr.
    ret, _, _ = proc.Call(uintptr(n))
    fmt.Printf("MyIntFunc(n) returns %d\n", ret) // 888
    n = -123
    // A LONG argument occupies 32 bits; converting through int32 makes that
    // truncation explicit. (A negative *constant*, e.g. uintptr(-123), does not compile.)
    ret, _, _ = proc.Call(uintptr(int32(n)))
    // Narrow the LONG result to int32 to recover its sign.
    fmt.Printf("MyIntFunc(-n) returns %d\n", int32(ret)) // -123

    // 3. Calling a function that outputs to an ANSI string with ANSI string input
    // The function signature is:
    //   LONG __stdcall MyStringFunc(LPSTR szOutput, DWORD nOutChars, LPCSTR szInput, DWORD nOptions);
    inputstr := "Hello World!"
    fmt.Printf("MyStringFunc input is \"%s\"\n", inputstr)
    proc = dll.NewProc("MyStringFunc")
    // 3a. Call with NULL output to find the required length (excluding the terminating null).
    // A bare 0 serves for both NULL and zero. Each pointer-to-uintptr conversion is
    // written inside the Call expression, the form the unsafe package permits.
    ret, _, _ = proc.Call(0, 0,
        uintptr(unsafe.Pointer(stringToCharPtr(inputstr))),
        0)
    fmt.Printf("MyStringFunc returns %d (expected 12)\n", int32(ret))
    nchars := longResult("MyStringFunc", ret)
    // 3b. Allocate the ANSI output buffer as bytes, plus one for the terminating null.
    strbuf := make([]byte, nchars+1)
    // 3c. Call again to write the output into the buffer.
    ret, _, _ = proc.Call(
        uintptr(unsafe.Pointer(&strbuf[0])),
        uintptr(nchars),
        uintptr(unsafe.Pointer(stringToCharPtr(inputstr))),
        0)
    nchars = longResult("MyStringFunc", ret)
    // 3d. Keep only the characters reported (dropping the trailing null) and convert to a Go string.
    outstr := string(strbuf[:nchars])
    fmt.Printf("MyStringFunc output is \"%s\"\n", outstr) // "Hello World!"

    // 4. Calling a function that outputs to a byte array with a byte array input.
    // The function signature is:
    //  LONG __stdcall MyByteFunc(BYTE *lpOutput, DWORD nOutBytes, CONST BYTE *lpInput, DWORD nInBytes, DWORD nOptions);
    proc = dll.NewProc("MyByteFunc")
    inbytes := []byte{0xde, 0xad, 0xbe, 0xef}
    fmt.Print("MyByteFunc input is (0x)", hex.EncodeToString(inbytes), "\n")
    flags := 0xFEFF // nOptions flags
    // 4a. Call with NULL output to find the required length.
    ret, _, _ = proc.Call(0, 0,
        uintptr(unsafe.Pointer(&inbytes[0])),
        uintptr(len(inbytes)),
        uintptr(flags))
    fmt.Printf("MyByteFunc returns %d (expected 4)\n", int32(ret))
    nbytes := longResult("MyByteFunc", ret)
    // 4b. Allocate an output buffer of exactly the required length.
    bbuf := make([]byte, nbytes)
    // 4c. Call again to work on the input.
    ret, _, _ = proc.Call(
        uintptr(unsafe.Pointer(&bbuf[0])),
        uintptr(nbytes),
        uintptr(unsafe.Pointer(&inbytes[0])),
        uintptr(len(inbytes)),
        uintptr(flags))
    nbytes = longResult("MyByteFunc", ret)
    // The byte array output is ready to work with.
    fmt.Print("MyByteFunc output is (0x)", hex.EncodeToString(bbuf[:nbytes]), "\n") // (0x)deadbeef

    // 5. Calling a function that outputs to a UTF-16 (Unicode) string with a UTF-16 string input
    // The function signature is:
    //   LONG __stdcall MyUnicodeFunc(LPWSTR wsOutput, DWORD nOutChars, LPCWSTR wsInput, DWORD nOptions)
    proc = dll.NewProc("MyUnicodeFunc")
    wstr := "Unicode: Привет 世界" // Hello World in Russian and Chinese
    fmt.Printf("MyUnicodeFunc input is \"%s\"\n", wstr)
    // 5a. Call with NULL output to find the required length (in UTF-16 code units).
    flags = 888
    ret, _, _ = proc.Call(0, 0,
        uintptr(unsafe.Pointer(stringToUTF16Ptr(wstr))),
        uintptr(flags))
    fmt.Printf("MyUnicodeFunc returns %d (expected 18)\n", int32(ret)) // Expected 18
    nwchars := longResult("MyUnicodeFunc", ret)
    // 5b. Allocate the UTF-16 output buffer, plus one for the terminating null.
    wstrbuf := make([]uint16, nwchars+1)
    // 5c. Call again to work on the input string.
    ret, _, _ = proc.Call(
        uintptr(unsafe.Pointer(&wstrbuf[0])),
        uintptr(nwchars),
        uintptr(unsafe.Pointer(stringToUTF16Ptr(wstr))),
        uintptr(flags))
    longResult("MyUnicodeFunc", ret)
    // 5d. Convert to a Go string; the buffer is null-terminated because make zeroes it
    // and it has one more element than the reported length.
    woutstr := utf16PtrToString(&wstrbuf[0])
    fmt.Printf("MyUnicodeFunc output is \"%s\"\n", woutstr)

    fmt.Println("\nALL DONE.")
}

// longResult interprets a LONG returned through proc.Call as an int, narrowing
// through int32 so the sign is taken from bit 31 of the 32-bit result. It
// panics, naming the function, if the value is negative (an error result).
func longResult(name string, ret uintptr) int {
    v := int(int32(ret))
    if v < 0 {
        panic(name + " returned an error")
    }
    return v
}