From 4352a10ab95af975e08f3019ca07027b7070e62e Mon Sep 17 00:00:00 2001 From: Matthias Koch <23187557+matthias-koch@users.noreply.github.com> Date: Wed, 8 Nov 2023 03:47:37 +0100 Subject: [PATCH] [ZEPPELIN-5875] Add: z.show works with subtypes of DataFrame (#4683) --- .../main/resources/python/zeppelin_context.py | 2 +- .../python/BasePythonInterpreterTest.java | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/python/src/main/resources/python/zeppelin_context.py b/python/src/main/resources/python/zeppelin_context.py index de3807d09e9..8223966d40e 100644 --- a/python/src/main/resources/python/zeppelin_context.py +++ b/python/src/main/resources/python/zeppelin_context.py @@ -179,7 +179,7 @@ def getDefaultChecked(self, defaultChecked): def show(self, p, **kwargs): if hasattr(p, '__name__') and p.__name__ == "matplotlib.pyplot": self.show_matplotlib(p, **kwargs) - elif type(p).__name__ == "DataFrame": # does not play well with sub-classes + elif any(t.__name__ == 'DataFrame' for t in type(p).mro()): # `isinstance(p, DataFrame)` would req `import pandas.core.frame.DataFrame` # and so a dependency on pandas self.show_dataframe(p, **kwargs) diff --git a/python/src/test/java/org/apache/zeppelin/python/BasePythonInterpreterTest.java b/python/src/test/java/org/apache/zeppelin/python/BasePythonInterpreterTest.java index 7469fdd4740..ecb903280b5 100644 --- a/python/src/test/java/org/apache/zeppelin/python/BasePythonInterpreterTest.java +++ b/python/src/test/java/org/apache/zeppelin/python/BasePythonInterpreterTest.java @@ -311,6 +311,22 @@ public void testZeppelinContext() throws InterpreterException, InterruptedExcept assertEquals(InterpreterResult.Type.TABLE, interpreterResultMessages.get(0).getType()); assertEquals("id\tname\n1\ta a\n2\tb b\n3\tc c\n", interpreterResultMessages.get(0).getData()); + // Pandas DataFrame with sub type + context = getInterpreterContext(); + result = interpreter.interpret("import pandas as pd\n" + + "class ExtendedDataFrame(pd.DataFrame):\n" + + " pass\n" + + "df = ExtendedDataFrame({'id':[1,2,3], 'name':['a\ta','b\\nb','c\\r\\nc']})\n" + + "z.show(df)", + context); + assertEquals(InterpreterResult.Code.SUCCESS, result.code(), + context.out.toInterpreterResultMessage().toString()); + interpreterResultMessages = context.out.toInterpreterResultMessage(); + assertEquals(1, interpreterResultMessages.size()); + assertEquals(InterpreterResult.Type.TABLE, interpreterResultMessages.get(0).getType()); + assertEquals("id\tname\n1\ta a\n2\tb b\n3\tc c\n", interpreterResultMessages.get(0).getData()); + + // Pandas DataFrame limited to three results context = getInterpreterContext(); result = interpreter.interpret("import pandas as pd\n" + "df = pd.DataFrame({'id':[1,2,3,4], 'name':['a','b','c', 'd']})\nz.show(df)", context);