From 91b00943814683fc05531da977fd35457b58ce21 Mon Sep 17 00:00:00 2001 From: Tony Locke Date: Thu, 15 Aug 2024 21:58:45 +0100 Subject: [PATCH] identifier() should quote if uppercase characters Previously the identifier() function wouldn't quote the identifier solely because it contained uppercase characters. The new behaviour of quoting if there are any uppercase characters is intended to be more in line with what people expect. It's also easier to recreate the old behaviour by calling lower() on the input beforehand, than it is to try and create the new behaviour with the old function. --- src/pg8000/converters.py | 8 ++++++-- test/test_converters.py | 6 +++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/pg8000/converters.py b/src/pg8000/converters.py index 73392a9..84593ff 100644 --- a/src/pg8000/converters.py +++ b/src/pg8000/converters.py @@ -772,6 +772,10 @@ def make_params(py_types, values): return tuple([make_param(py_types, v) for v in values]) +def _quote_letter(c): + return c.isupper() if c.isalpha() else True + + def identifier(sql): if not isinstance(sql, str): raise InterfaceError("identifier must be a str") @@ -779,10 +783,10 @@ def identifier(sql): if len(sql) == 0: raise InterfaceError("identifier must be > 0 characters in length") - quote = not sql[0].isalpha() + quote = _quote_letter(sql[0]) for c in sql[1:]: - if not (c.isalpha() or c.isdecimal() or c in "_$"): + if _quote_letter(c) and c not in "0123456789_$": if c == "\u0000": raise InterfaceError( "identifier cannot contain the code zero character" diff --git a/test/test_converters.py b/test/test_converters.py index 41fd5cc..4781aec 100644 --- a/test/test_converters.py +++ b/test/test_converters.py @@ -345,7 +345,11 @@ def test_identifier_quoted_null(): (" Table", '" Table"'), ("A Table", '"A Table"'), ('A " Table', '"A "" Table"'), - ("Table$", "Table$"), + ("table$", "table$"), + ("Table$", '"Table$"'), + ("tableఐ", "tableఐ"), # Unicode character 0C10 which is uncased + ("table", "table"), + ("tAble", '"tAble"'), ], ) def test_identifier_success(value, expected):