Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions include/dxc/WinAdapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -916,19 +916,35 @@ unsigned int SysStringLen(const BSTR bstrString);
// RAII style mechanism for setting/unsetting a locale for the specified Windows
// codepage
class ScopedLocale {
const char *m_prevLocale;
locale_t Utf8Locale = nullptr;
locale_t PrevLocale = nullptr;

public:
explicit ScopedLocale(uint32_t codePage)
: m_prevLocale(setlocale(LC_ALL, nullptr)) {
assert((codePage == CP_UTF8) &&
explicit ScopedLocale(uint32_t CodePage) {
assert((CodePage == CP_UTF8) &&
"Support for Linux only handles UTF8 code pages");
setlocale(LC_ALL, "en_US.UTF-8");
Utf8Locale = newlocale(LC_CTYPE_MASK, "C.UTF-8", NULL);
if (!Utf8Locale)
Utf8Locale = newlocale(LC_CTYPE_MASK, "C.utf8", NULL);
if (!Utf8Locale)
Utf8Locale = newlocale(LC_CTYPE_MASK, "en_US.UTF-8", NULL);
assert(Utf8Locale && "Failed to create UTF-8 locale");
if (!Utf8Locale)
return;
PrevLocale = uselocale(Utf8Locale);
assert(PrevLocale && "Failed to set locale to UTF-8");
if (!PrevLocale) {
freelocale(Utf8Locale);
Utf8Locale = nullptr;
}
}
~ScopedLocale() {
if (m_prevLocale != nullptr) {
setlocale(LC_ALL, m_prevLocale);
}
if (PrevLocale != nullptr)
uselocale(PrevLocale);
if (Utf8Locale)
freelocale(Utf8Locale);
PrevLocale = nullptr;
Utf8Locale = nullptr;
}
};

Expand Down
131 changes: 90 additions & 41 deletions lib/DxcSupport/Unicode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ int MultiByteToWideChar(uint32_t /*CodePage*/, uint32_t /*dwFlags*/,
const char *lpMultiByteStr, int cbMultiByte,
wchar_t *lpWideCharStr, int cchWideChar) {

if (cbMultiByte == 0) {
// Check for invalid sizes or potential overflow.
if (cbMultiByte == 0 || cbMultiByte < -1 || cbMultiByte > (INT32_MAX - 1) ||
cchWideChar < 0 || cchWideChar > (INT32_MAX - 1)) {
Comment on lines +32 to +33
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't cbMultiByte > (INT32_MAX - 1) equivalent to cbMultiByte == INT32_MAX? Unless we were to change cbMultiByte and cchWideChar to larger types I think the equality check is clearer, no?

SetLastError(ERROR_INVALID_PARAMETER);
return 0;
}
Expand All @@ -42,18 +44,17 @@ int MultiByteToWideChar(uint32_t /*CodePage*/, uint32_t /*dwFlags*/,
++cbMultiByte;
}
// If zero is given as the destination size, this function should
// return the required size (including the null-terminating character).
// return the required size (including or excluding the null-terminating
// character depending on whether the input included the null-terminator).
// This is the behavior of mbstowcs when the target is null.
if (cchWideChar == 0) {
lpWideCharStr = nullptr;
} else if (cchWideChar < cbMultiByte) {
SetLastError(ERROR_INSUFFICIENT_BUFFER);
return 0;
}

ScopedLocale utf8_locale_scope(CP_UTF8);

bool isNullTerminated = false;
size_t rv;
const char *prevLocale = setlocale(LC_ALL, nullptr);
setlocale(LC_ALL, "en_US.UTF-8");
if (lpMultiByteStr[cbMultiByte - 1] != '\0') {
char *srcStr = (char *)malloc((cbMultiByte + 1) * sizeof(char));
strncpy(srcStr, lpMultiByteStr, cbMultiByte);
Expand All @@ -62,14 +63,29 @@ int MultiByteToWideChar(uint32_t /*CodePage*/, uint32_t /*dwFlags*/,
free(srcStr);
} else {
rv = mbstowcs(lpWideCharStr, lpMultiByteStr, cchWideChar);
isNullTerminated = true;
}

if (rv == ~(size_t)0) {
// mbstowcs returns -1 on error.
SetLastError(ERROR_INVALID_PARAMETER);
return 0;
}

if (prevLocale)
setlocale(LC_ALL, prevLocale);
// Return value of mbstowcs (rv) excludes the terminating character.
// Matching MultiByteToWideChar requires returning the size written including
// the null terminator if the input was null-terminated, otherwise it
// returns the size written excluding the null terminator.
if (isNullTerminated)
rv += 1;

// Check for overflow when returning the size.
if (rv >= INT32_MAX) {
SetLastError(ERROR_INVALID_PARAMETER);
return 0; // Overflow error
}

if (rv == (size_t)cbMultiByte)
return rv;
return rv + 1; // mbstowcs excludes the terminating character
return rv;
}

// WideCharToMultiByte is a Windows-specific method.
Expand All @@ -84,7 +100,9 @@ int WideCharToMultiByte(uint32_t /*CodePage*/, uint32_t /*dwFlags*/,
*lpUsedDefaultChar = FALSE;
}

if (cchWideChar == 0) {
// Check for invalid sizes or potential overflow.
if (cchWideChar == 0 || cchWideChar < -1 || cchWideChar > (INT32_MAX - 1) ||
cbMultiByte < 0 || cbMultiByte > (INT32_MAX - 1)) {
SetLastError(ERROR_INVALID_PARAMETER);
return 0;
}
Expand All @@ -98,18 +116,17 @@ int WideCharToMultiByte(uint32_t /*CodePage*/, uint32_t /*dwFlags*/,
++cchWideChar;
}
// If zero is given as the destination size, this function should
// return the required size (including the null-terminating character).
// return the required size (including or excluding the null-terminating
// character depending on whether the input included the null-terminator).
// This is the behavior of wcstombs when the target is null.
if (cbMultiByte == 0) {
lpMultiByteStr = nullptr;
} else if (cbMultiByte < cchWideChar) {
SetLastError(ERROR_INSUFFICIENT_BUFFER);
return 0;
}

ScopedLocale utf8_locale_scope(CP_UTF8);

bool isNullTerminated = false;
size_t rv;
const char *prevLocale = setlocale(LC_ALL, nullptr);
setlocale(LC_ALL, "en_US.UTF-8");
if (lpWideCharStr[cchWideChar - 1] != L'\0') {
wchar_t *srcStr = (wchar_t *)malloc((cchWideChar + 1) * sizeof(wchar_t));
wcsncpy(srcStr, lpWideCharStr, cchWideChar);
Expand All @@ -118,21 +135,40 @@ int WideCharToMultiByte(uint32_t /*CodePage*/, uint32_t /*dwFlags*/,
free(srcStr);
} else {
rv = wcstombs(lpMultiByteStr, lpWideCharStr, cbMultiByte);
isNullTerminated = true;
}

if (rv == ~(size_t)0) {
// wcstombs returns -1 on error.
SetLastError(ERROR_INVALID_PARAMETER);
return 0;
}

if (prevLocale)
setlocale(LC_ALL, prevLocale);
// Return value of wcstombs (rv) excludes the terminating character.
// Matching WideCharToMultiByte requires returning the size written including
// the null terminator if the input was null-terminated, otherwise it
// returns the size written excluding the null terminator.
if (isNullTerminated)
rv += 1;

// Check for overflow when returning the size.
if (rv >= INT32_MAX) {
SetLastError(ERROR_INVALID_PARAMETER);
return 0; // Overflow error
}

if (rv == (size_t)cchWideChar)
return rv;
return rv + 1; // mbstowcs excludes the terminating character
return rv;
}
#endif // _WIN32

namespace Unicode {

bool WideToEncodedString(const wchar_t *text, size_t cWide, DWORD cp,
DWORD flags, std::string *pValue, bool *lossy) {
DXASSERT_NOMSG(cWide == ~(size_t)0 || cWide < INT32_MAX);
if (text == nullptr || pValue == nullptr || cWide == 0 || cWide >= INT32_MAX)
return false;

BOOL usedDefaultChar;
LPBOOL pUsedDefaultChar = (lossy == nullptr) ? nullptr : &usedDefaultChar;
if (lossy != nullptr)
Expand All @@ -147,31 +183,37 @@ bool WideToEncodedString(const wchar_t *text, size_t cWide, DWORD cp,
return true;
}

int cbUTF8 = ::WideCharToMultiByte(cp, flags, text, cWide, nullptr, 0,
nullptr, pUsedDefaultChar);
int cbUTF8 = ::WideCharToMultiByte(cp, flags, text, static_cast<int>(cWide),
nullptr, 0, nullptr, pUsedDefaultChar);
if (cbUTF8 == 0)
return false;

pValue->resize(cbUTF8);

cbUTF8 = ::WideCharToMultiByte(cp, flags, text, cWide, &(*pValue)[0],
pValue->size(), nullptr, pUsedDefaultChar);
cbUTF8 = ::WideCharToMultiByte(cp, flags, text, static_cast<int>(cWide),
&(*pValue)[0], pValue->size(), nullptr,
pUsedDefaultChar);
DXASSERT(cbUTF8 > 0, "otherwise contents have changed");
DXASSERT((*pValue)[pValue->size()] == '\0',
"otherwise string didn't null-terminate after resize() call");
if ((cWide == ~(size_t)0 || text[cWide - 1] == L'\0') &&
(*pValue)[pValue->size() - 1] == '\0') {
// When the input is null-terminated, the output includes the null
// terminator. Reduce the size by 1 to remove the embedded null terminator
// inside the string.
pValue->resize(cbUTF8 - 1);
}

if (lossy != nullptr)
*lossy = usedDefaultChar;
return true;
}

bool UTF8ToWideString(const char *pUTF8, std::wstring *pWide) {
size_t cbUTF8 = (pUTF8 == nullptr) ? 0 : strlen(pUTF8);
return UTF8ToWideString(pUTF8, cbUTF8, pWide);
return UTF8ToWideString(pUTF8, -1, pWide);
}

bool UTF8ToWideString(const char *pUTF8, size_t cbUTF8, std::wstring *pWide) {
DXASSERT_NOMSG(pWide != nullptr);
DXASSERT_NOMSG(cbUTF8 == ~(size_t)0 || cbUTF8 < INT32_MAX);

// Handle zero-length as a special case; it's a special value to indicate
// errors in MultiByteToWideChar.
Expand All @@ -181,17 +223,23 @@ bool UTF8ToWideString(const char *pUTF8, size_t cbUTF8, std::wstring *pWide) {
}

int cWide = ::MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, pUTF8,
cbUTF8, nullptr, 0);
static_cast<int>(cbUTF8), nullptr, 0);
if (cWide == 0)
return false;

pWide->resize(cWide);

cWide = ::MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, pUTF8, cbUTF8,
&(*pWide)[0], pWide->size());
cWide = ::MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, pUTF8,
static_cast<int>(cbUTF8), &(*pWide)[0],
pWide->size());
DXASSERT(cWide > 0, "otherwise contents changed");
DXASSERT((*pWide)[pWide->size()] == L'\0',
"otherwise wstring didn't null-terminate after resize() call");
if ((cbUTF8 == ~(size_t)0 || pUTF8[cbUTF8 - 1] == '\0') &&
(*pWide)[pWide->size() - 1] == '\0') {
// When the input is null-terminated, the output includes the null
// terminator. Reduce the size by 1 to remove the embedded null terminator
// inside the string.
pWide->resize(cWide - 1);
}
return true;
}

Expand All @@ -213,11 +261,12 @@ bool UTF8ToConsoleString(const char *text, size_t textLen, std::string *pValue,
if (!UTF8ToWideString(text, textLen, &text16)) {
return false;
}
return WideToConsoleString(text16.c_str(), text16.length(), pValue, lossy);
return WideToConsoleString(text16.c_str(), text16.length() + 1, pValue,
lossy);
}

bool UTF8ToConsoleString(const char *text, std::string *pValue, bool *lossy) {
return UTF8ToConsoleString(text, strlen(text), pValue, lossy);
return UTF8ToConsoleString(text, ~(size_t)0, pValue, lossy);
}

bool WideToConsoleString(const wchar_t *text, size_t textLen,
Expand All @@ -230,7 +279,7 @@ bool WideToConsoleString(const wchar_t *text, size_t textLen,

bool WideToConsoleString(const wchar_t *text, std::string *pValue,
bool *lossy) {
return WideToConsoleString(text, wcslen(text), pValue, lossy);
return WideToConsoleString(text, ~(size_t)0, pValue, lossy);
}

bool WideToUTF8String(const wchar_t *pWide, size_t cWide, std::string *pUTF8) {
Expand All @@ -242,7 +291,7 @@ bool WideToUTF8String(const wchar_t *pWide, size_t cWide, std::string *pUTF8) {
bool WideToUTF8String(const wchar_t *pWide, std::string *pUTF8) {
DXASSERT_NOMSG(pWide != nullptr);
DXASSERT_NOMSG(pUTF8 != nullptr);
return WideToEncodedString(pWide, wcslen(pWide), CP_UTF8, 0, pUTF8, nullptr);
return WideToEncodedString(pWide, ~(size_t)0, CP_UTF8, 0, pUTF8, nullptr);
}

std::string WideToUTF8StringOrThrow(const wchar_t *pWide) {
Expand Down
66 changes: 33 additions & 33 deletions tools/clang/unittests/HLSL/CompilerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,13 @@ class CompilerTest : public ::testing::Test {
void TestEncodingImpl(const void *sourceData, size_t sourceSize,
UINT32 codePage, const void *includedData,
size_t includedSize, const WCHAR *encoding = nullptr);
// Convenience overload of TestEncodingImpl that takes string objects instead
// of raw pointer / byte-size pairs.
//
// Forwards to the pointer-based TestEncodingImpl overload, passing each
// string's data() pointer together with its length in BYTES (character count
// times the size of the string's character type). Because the two character
// types are deduced independently, a narrow source can be paired with a wide
// include and vice versa. size() does not count the string's terminating
// null, so no terminator is included in the forwarded byte counts.
//
// The strings are only read, so they are taken by const reference to avoid
// copying them on every call.
//
//   source   - text forwarded as sourceData/sourceSize.
//   codePage - forwarded unchanged as the code page argument.
//   included - text forwarded as includedData/includedSize.
//   encoding - forwarded unchanged; defaults to nullptr.
template <typename T1, typename T2>
void TestEncodingImpl(const std::basic_string<T1> &source, UINT32 codePage,
                      const std::basic_string<T2> &included,
                      const WCHAR *encoding = nullptr) {
  // Byte counts, not character counts: the pointer-based overload receives
  // untyped buffers and therefore needs sizes scaled by the character width.
  const size_t sourceBytes = source.size() * sizeof(T1);
  const size_t includedBytes = included.size() * sizeof(T2);
  TestEncodingImpl(source.data(), sourceBytes, codePage, included.data(),
                   includedBytes, encoding);
}
TEST_METHOD(CompileWithEncodeFlagTestSource)

#if _ITERATOR_DEBUG_LEVEL == 0
Expand Down Expand Up @@ -3636,54 +3643,47 @@ void CompilerTest::TestEncodingImpl(const void *sourceData, size_t sourceSize,

TEST_F(CompilerTest, CompileWithEncodeFlagTestSource) {

std::string sourceUtf8 = "#include \"include.hlsl\"\r\n"
"float4 main() : SV_Target { return 0; }";
std::string includeUtf8 = "// Comment\n";
std::string SourceUtf8 = "#include \"include.hlsl\"\n"
"float4 main() : SV_Target { return Buf[0]; }";
std::string IncludeUtf8 = "Buffer<float4> Buf;\n";
std::string utf8BOM = "\xEF"
"\xBB"
"\xBF"; // UTF-8 BOM
std::string includeUtf8BOM = utf8BOM + includeUtf8;
std::string IncludeUtf8BOM = utf8BOM + IncludeUtf8;

std::wstring sourceWide = L"#include \"include.hlsl\"\r\n"
L"float4 main() : SV_Target { return 0; }";
std::wstring includeWide = L"// Comments\n";
std::wstring utf16BOM = L"\xFEFF"; // UTF-16 LE BOM
std::wstring includeUtf16BOM = utf16BOM + includeWide;
std::wstring SourceWide = L"#include \"include.hlsl\"\n"
L"float4 main() : SV_Target { return Buf[0]; }";
std::wstring IncludeWide = L"Buffer<float4> Buf;\n";

// Included files interpreted with encoding option if no BOM
TestEncodingImpl(sourceUtf8.data(), sourceUtf8.size(), DXC_CP_UTF8,
includeUtf8.data(), includeUtf8.size(), L"utf8");
// Windows: UTF-16 BOM is '\xFEFF'
// *nix: UTF-32 BOM is L'\x0000FEFF'
// Thus, BOM wide character value is identical for UTF-16 and UTF-32.
// Endianness will be native, since we are using wide strings directly.
std::wstring WideBOM = L"\xFEFF";

std::wstring IncludeWideBOM = WideBOM + IncludeWide;

TestEncodingImpl(sourceWide.data(), sourceWide.size() * sizeof(L'A'),
DXC_CP_WIDE, includeWide.data(),
includeWide.size() * sizeof(L'A'), L"wide");
// Included files interpreted with encoding option if no BOM
TestEncodingImpl(SourceUtf8, DXC_CP_UTF8, IncludeUtf8, L"utf8");
TestEncodingImpl(SourceWide, DXC_CP_WIDE, IncludeWide, L"wide");

// Encoding option ignored if BOM present
TestEncodingImpl(sourceUtf8.data(), sourceUtf8.size(), DXC_CP_UTF8,
includeUtf8BOM.data(), includeUtf8BOM.size(), L"wide");
TestEncodingImpl(SourceUtf8, DXC_CP_UTF8, IncludeUtf8BOM, L"wide");
TestEncodingImpl(SourceWide, DXC_CP_WIDE, IncludeWideBOM, L"utf8");

TestEncodingImpl(sourceWide.data(), sourceWide.size() * sizeof(L'A'),
DXC_CP_WIDE, includeUtf16BOM.data(),
includeUtf16BOM.size() * sizeof(L'A'), L"utf8");
// Encoding option ignored if BOM present - different encoding for source
TestEncodingImpl(SourceWide, DXC_CP_WIDE, IncludeUtf8BOM, L"wide");
TestEncodingImpl(SourceUtf8, DXC_CP_UTF8, IncludeWideBOM, L"utf8");

// Source file interpreted according to DxcBuffer encoding if not CP_ACP
// Included files interpreted with encoding option if no BOM
TestEncodingImpl(sourceUtf8.data(), sourceUtf8.size(), DXC_CP_UTF8,
includeWide.data(), includeWide.size() * sizeof(L'A'),
L"wide");

TestEncodingImpl(sourceWide.data(), sourceWide.size() * sizeof(L'A'),
DXC_CP_WIDE, includeUtf8.data(), includeUtf8.size(),
L"utf8");
TestEncodingImpl(SourceUtf8, DXC_CP_UTF8, IncludeWide, L"wide");
TestEncodingImpl(SourceWide, DXC_CP_WIDE, IncludeUtf8, L"utf8");

// Source file interpreted by encoding option if source DxcBuffer encoding =
// CP_ACP (default)
TestEncodingImpl(sourceUtf8.data(), sourceUtf8.size(), DXC_CP_ACP,
includeUtf8.data(), includeUtf8.size(), L"utf8");

TestEncodingImpl(sourceWide.data(), sourceWide.size() * sizeof(L'A'),
DXC_CP_ACP, includeWide.data(),
includeWide.size() * sizeof(L'A'), L"wide");
TestEncodingImpl(SourceUtf8, DXC_CP_ACP, IncludeUtf8, L"utf8");
TestEncodingImpl(SourceWide, DXC_CP_ACP, IncludeWide, L"wide");
}

TEST_F(CompilerTest, CompileWhenODumpThenOptimizerMatch) {
Expand Down
Loading