diff --git a/neanderthal/tech/v3/dataset/tribuo_test.clj b/neanderthal/tech/v3/dataset/tribuo_test.clj index a10064b6..f2152793 100644 --- a/neanderthal/tech/v3/dataset/tribuo_test.clj +++ b/neanderthal/tech/v3/dataset/tribuo_test.clj @@ -11,7 +11,6 @@ [org.tribuo.regression.xgboost XGBoostRegressionTrainer])) - (defn classification-example-ds [x] (let [x (if (integer? x) @@ -80,3 +79,11 @@ (is (= "class org.tribuo.classification.dtree.CARTClassificationTrainer" (str (class trainer)))))) + +(deftest test-keyword-name + (testing "string name (OK)" + (is (-> (ds/->dataset [{"a" 1}] {:dataset-name "string name"}) + (tribuo/make-regression-datasource "a")))) + (testing "keyword name (Error)" + (is (-> (ds/->dataset [{"a" 1}] {:dataset-name :keyword/name}) + (tribuo/make-regression-datasource "a"))))) diff --git a/src/tech/v3/libs/tribuo.clj b/src/tech/v3/libs/tribuo.clj index 7abbcc0b..ec2b297a 100644 --- a/src/tech/v3/libs/tribuo.clj +++ b/src/tech/v3/libs/tribuo.clj @@ -55,7 +55,7 @@ _unnamed [5 1]: [org.tribuo.regression.evaluation RegressionEvaluator RegressionEvaluation] [com.oracle.labs.mlrg.olcut.config ConfigurationManager] [com.oracle.labs.mlrg.olcut.config.json JsonConfigFactory])) - + (set! *warn-on-reflection* true) @@ -157,13 +157,22 @@ _unnamed [5 1]: cnames (->double-array (feat-data idx)))) (meta outputs)))) +(defn- safe-str + [n] + (cond (string? n) + n + (or (keyword? n) (symbol? n)) + (if-let [nn (namespace n)] + (str nn "/" (name n)) + (str (name n))))) + (defn- ds->datasource ^DataSource [ds ds->outputs] (let [examples (ds->examples ds ds->outputs) {:keys [output-factory provenance]} (meta examples) provenance (or provenance - (SimpleDataSourceProvenance. (:name (meta ds)) output-factory))] + (SimpleDataSourceProvenance. (safe-str (:name (meta ds))) output-factory))] (when-not output-factory (throw (RuntimeException. "Output factory not present in example metadata"))) (reify DataSource