diff --git a/devtools/deploy_mysqld.py b/devtools/deploy_mysqld.py index 2fa5d03..499797f 100644 --- a/devtools/deploy_mysqld.py +++ b/devtools/deploy_mysqld.py @@ -8,6 +8,11 @@ from dataclasses import dataclass, asdict from testcontainers.mysql import MySqlContainer from docker.models.containers import Container +from testcontainers.core.config import testcontainers_config as c + +from sqlalchemy import create_engine + +c.ryuk_disabled = True @click.group() @@ -80,7 +85,6 @@ def mDeploy(version): ) return - os.environ["TESTCONTAINERS_RYUK_DISABLED"] = "true" mContainer = MySqlContainer(f"mysql:{version}") datadir = os.getcwd() + f"/datadir/{version}" mContainer.with_volume_mapping(datadir, "/var/lib/mysql", "rw") @@ -105,14 +109,34 @@ def deploy(version): @main.command() @click.option("--version", type=click.STRING, default="8.0.17") -def connect(version): +@click.option("--sql", type=click.STRING, default="") +def connect(version, sql): deploy_container = load_deploy() if version not in deploy_container: mDeploy(version) deploy_container = load_deploy() - os.system(deploy_container.get(version).cmd) + if sql == "": + os.system(deploy_container.get(version).cmd) + else: + os.system(deploy_container.get(version).cmd + f" -e '{sql}'") +@main.command() +@click.option("--version", type=click.STRING, default="") +@click.option("--sql", type=click.STRING, default="") +@click.option("--file", type=click.STRING, default="") +def exec(version, sql, file): + deploy_container = load_deploy() + if version not in deploy_container: + mDeploy(version) + deploy_container = load_deploy() + engine = create_engine(deploy_container.get(version).url) + if file != "": + with open(file, "r") as f: + sql = f.read() + with engine.connect() as conn: + result = conn.exec_driver_sql(sql) + print(result.all()[0][1]) if __name__ == "__main__": main() diff --git a/src/pyinnodb/cli/sql.py b/src/pyinnodb/cli/sql.py index 19bf52e..ee6c57a 100644 --- a/src/pyinnodb/cli/sql.py +++ b/src/pyinnodb/cli/sql.py @@ -15,7 +15,8 @@ @click.pass_context @click.option("--mode", type=click.Choice(["sdi", "ddl", "dump"]), default="ddl") @click.option("--sdi-idx", type=click.INT, default=0) -def tosql(ctx, mode, sdi_idx): +@click.option("--schema/--no-schema", default=True) +def tosql(ctx, mode, sdi_idx, schema): """dump the ddl/dml/sdi of the ibd table ddl) output the create table ddl; @@ -35,7 +36,10 @@ def tosql(ctx, mode, sdi_idx): elif mode == "ddl": table_object = Table(**sdi_page.ddl(f, sdi_idx)["dd_object"]) - table_name = f"`{table_object.schema_ref}`.`{table_object.name}`" + if schema: + table_name = f"`{table_object.schema_ref}`.`{table_object.name}`" + else: + table_name = f"`{table_object.name}`" columns_dec = [] for c in table_object.columns: if ( @@ -55,7 +59,7 @@ def tosql(ctx, mode, sdi_idx): columns_dec.extend(constraints) foreign_keys = table_object.gen_foreign_key() columns_dec.extend(foreign_keys) - columns_dec = "\n " + ",\n ".join(columns_dec) + "\n" + columns_dec = "\n " + ",\n ".join(columns_dec) + "\n" table_collation = const.get_collation_by_id(table_object.collation_id) parts = table_object.gen_sql_for_partition() desc = f"ENGINE={table_object.engine} DEFAULT CHARSET={table_collation.CHARACTER_SET_NAME} COLLATE={table_collation.COLLATION_NAME}" @@ -65,7 +69,7 @@ def tosql(ctx, mode, sdi_idx): else "" ) print( - f"CREATE TABLE {table_name} ({columns_dec}) {desc} {chr(10)+parts if parts else ''}{comment}" + f"CREATE TABLE {table_name} ({columns_dec}) {desc}{parts}{comment}" ) else: table_object = Table(**sdi_page.ddl(f, sdi_idx)["dd_object"]) diff --git a/src/pyinnodb/const/dd_column_type.py b/src/pyinnodb/const/dd_column_type.py index c96a574..2e86210 100644 --- a/src/pyinnodb/const/dd_column_type.py +++ b/src/pyinnodb/const/dd_column_type.py @@ -58,6 +58,11 @@ def is_var(cls, t, mysqld_version=None): else: return True + @classmethod + def is_string(cls, t): + tt = cls(t) + return tt in _string_type + @classmethod def is_big(cls, t): return cls(t) in _big_type @@ -92,6 +97,16 @@ def is_big(cls, t): DDColumnType.VECTOR, ] +_string_type = [ + DDColumnType.VARCHAR, + DDColumnType.STRING, + DDColumnType.VAR_STRING, + DDColumnType.BLOB, + DDColumnType.MEDIUM_BLOB, + DDColumnType.LONG_BLOB, + DDColumnType.TINY_BLOB, +] + _big_type = [ DDColumnType.MEDIUM_BLOB, DDColumnType.LONG_BLOB, diff --git a/src/pyinnodb/sdi/table.py b/src/pyinnodb/sdi/table.py index 285397f..97b5cc8 100644 --- a/src/pyinnodb/sdi/table.py +++ b/src/pyinnodb/sdi/table.py @@ -205,7 +205,12 @@ def get_collation(self): return coll def gen_sql(self): - sql = f"`{self.name}` {self.column_type_utf8}{'' if self.is_nullable or self.is_virtual else ' NOT NULL'}" + sql = f"`{self.name}` {self.column_type_utf8}" + if self.collation_id != 255 and self.collation_id != 63: + collation = const.get_collation_by_id(self.collation_id) + if DDColumnType.is_string(self.type): + pass + sql += f"{'' if self.is_nullable or self.is_virtual else ' NOT NULL'}" sql += f"{' AUTO_INCREMENT' if self.is_auto_increment else ''}" if self.default_option != "": sql += f" DEFAULT ({self.default_option})" @@ -1024,18 +1029,20 @@ def gen_sql_for_partition(self) -> str: f"PARTITION {par.name} VALUES LESS THAN ({par.description_utf8})" ) parts = ",\n ".join(parts) + "\n" - return f"{p}{parts})*/" + return "\n" + f"{p}{parts})*/" elif pt == const.partition.PartitionType.PT_HASH: - return f"/*!50100 PARTITION BY HASH ({self.partition_expression_utf8}) PARTITIONS ({len(self.partitions)})*/" + return "\n" + f"/*!50100 PARTITION BY HASH ({self.partition_expression_utf8}) PARTITIONS ({len(self.partitions)})*/" elif pt == const.partition.PartitionType.PT_KEY_55: - return f"/*!50100 PARTITION BY KEY ({self.partition_expression_utf8}) PARTITIONS ({len(self.partitions)})*/" + return "\n" + f"/*!50100 PARTITION BY KEY ({self.partition_expression_utf8}) PARTITIONS ({len(self.partitions)})*/" elif pt == const.partition.PartitionType.PT_LIST: p = f"/*!50100 PARTITION BY LIST ({self.partition_expression_utf8}) (\n " parts = [] for par in self.partitions: parts.append(f"PARTITION {par.name} VALUES IN ({par.description_utf8})") parts = ",\n ".join(parts) + "\n" - return f"{p}{parts})*/" + return "\n" + f"{p}{parts})*/" + else: + return "" def should_ext():