#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, Union, List, Tuple

import numpy as np
import pandas as pd

from pyspark import keyword_only
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasProbabilityCol
from pyspark.ml.connect.base import Evaluator
from pyspark.ml.connect.io_utils import ParamsReadWrite
from pyspark.ml.connect.util import aggregate_dataframe
from pyspark.sql import DataFrame


class _TorchMetricEvaluator(Evaluator):
    """
    Base class for evaluators backed by a torch metric object.

    Subclasses provide three hooks:

    * ``_get_torch_metric`` returns a new, empty metric object.
    * ``_get_input_cols`` names the columns the metric reads.
    * ``_get_metric_update_inputs`` turns a pandas batch into the positional
      arguments of the metric's ``update``.

    ``_evaluate`` drives these through ``aggregate_dataframe``. It combines partial
    metric states with ``merge_state`` and reads the final value with
    ``compute().item()``.
    """

    metricName: Param[str] = Param(
        Params._dummy(),
        "metricName",
        "metric name for the evaluator; valid values depend on the concrete evaluator "
        "(e.g. 'rmse', 'mse' and 'r2' for the regression evaluator)",
        typeConverter=TypeConverters.toString,
    )

    def getMetricName(self) -> str:
        """
        Gets the value of metricName or its default value.

        .. versionadded:: 3.5.0
        """
        return self.getOrDefault(self.metricName)

    def _get_torch_metric(self) -> Any:
        """Return a new, empty metric object. Must be overridden by subclasses."""
        raise NotImplementedError()

    def _get_input_cols(self) -> List[str]:
        """Return the names of the columns the metric reads. Must be overridden."""
        raise NotImplementedError()

    def _get_metric_update_inputs(self, dataset: "pd.DataFrame") -> Tuple[Any, Any]:
        """Convert a pandas batch into the positional arguments of the metric's ``update``."""
        raise NotImplementedError()

    def _evaluate(self, dataset: Union["DataFrame", "pd.DataFrame"]) -> float:
        """
        Compute the metric over ``dataset`` (a Spark or pandas DataFrame).

        Returns the metric's ``compute()`` result converted to a Python scalar.
        Raises whatever ``_get_torch_metric`` raises for an unsupported metric name,
        before any aggregation starts.
        """
        # Build one metric eagerly on the driver purely as validation, so an
        # unsupported metric name fails before any (possibly distributed) work runs.
        self._get_torch_metric()

        def local_agg_fn(pandas_df: "pd.DataFrame") -> Any:
            # Each call gets its own metric object. Partial states are combined
            # pairwise by ``merge_agg_state``. If every call returned one shared
            # captured object, a merge could fold a metric into itself and count
            # earlier batches more than once.
            torch_metric = self._get_torch_metric()
            torch_metric.update(*self._get_metric_update_inputs(pandas_df))
            return torch_metric

        def merge_agg_state(state1: Any, state2: Any) -> Any:
            # Defensive: treat a missing state as "no data".
            # NOTE(review): confirm whether aggregate_dataframe can pass None
            # (e.g. for an empty partition).
            if state1 is None:
                return state2
            if state2 is None:
                return state1
            state1.merge_state([state2])
            return state1

        def agg_state_to_result(state: Any) -> Any:
            if state is None:
                raise ValueError("Cannot evaluate the metric on an empty dataset.")
            return state.compute().item()

        return aggregate_dataframe(
            dataset,
            self._get_input_cols(),
            local_agg_fn,
            merge_agg_state,
            agg_state_to_result,
        )


def _get_rmse_torchmetric() -> Any:
    """
    Return a new metric object that computes the root mean squared error.

    The metric subclasses ``torcheval.metrics.MeanSquaredError`` and only changes
    ``compute`` to take the square root. Updating and merging therefore behave
    exactly as for ``MeanSquaredError``. The imports are function-local, presumably
    so that torch/torcheval are only needed when this metric is actually requested.
    """
    import torch
    import torcheval.metrics as torchmetrics

    class _RootMeanSquaredError(torchmetrics.MeanSquaredError):
        def compute(self: Any) -> torch.Tensor:
            return torch.sqrt(super().compute())

    return _RootMeanSquaredError()
class RegressionEvaluator(_TorchMetricEvaluator, HasLabelCol, HasPredictionCol, ParamsReadWrite):
    """
    Evaluator for Regression, which expects input columns prediction and label.
    Supported metrics are 'rmse', 'mse' and 'r2'.

    .. versionadded:: 3.5.0

    Examples
    --------
    >>> from pyspark.ml.connect.evaluation import RegressionEvaluator
    >>> eva = RegressionEvaluator(metricName='mse')
    >>> dataset = spark.createDataFrame(
    ...     [(1.0, 2.0), (-1.0, -1.5)], schema=['label', 'prediction']
    ... )
    >>> eva.evaluate(dataset)
    0.625
    >>> eva.isLargerBetter()
    False
    """

    @keyword_only
    def __init__(
        self,
        *,
        metricName: str = "rmse",
        labelCol: str = "label",
        predictionCol: str = "prediction",
    ) -> None:
        """
        __init__(self, *, metricName='rmse', labelCol='label', predictionCol='prediction') -> None:
        """
        super().__init__()
        self._set(metricName=metricName, labelCol=labelCol, predictionCol=predictionCol)

    def _get_torch_metric(self) -> Any:
        """
        Return a new torcheval metric object for the configured metric name.

        Raises
        ------
        ValueError
            If the metric name is not one of 'mse', 'r2' or 'rmse'.
        """
        import torcheval.metrics as torchmetrics

        metric_name = self.getMetricName()

        if metric_name == "mse":
            return torchmetrics.MeanSquaredError()
        if metric_name == "r2":
            return torchmetrics.R2Score()
        if metric_name == "rmse":
            return _get_rmse_torchmetric()

        raise ValueError(f"Unsupported regressor evaluator metric name: {metric_name}")

    def _get_input_cols(self) -> List[str]:
        # Order matches the tensors returned by _get_metric_update_inputs.
        return [self.getPredictionCol(), self.getLabelCol()]

    def _get_metric_update_inputs(self, dataset: "pd.DataFrame") -> Tuple[Any, Any]:
        """
        Convert the prediction and label columns of a pandas batch into tensors.

        Returns ``(predictions, labels)``, in the order the metric's
        ``update(input, target)`` expects.
        """
        import torch

        preds_tensor = torch.tensor(dataset[self.getPredictionCol()].to_numpy())
        labels_tensor = torch.tensor(dataset[self.getLabelCol()].to_numpy())
        return preds_tensor, labels_tensor

    def isLargerBetter(self) -> bool:
        """
        Whether a larger value of the configured metric means a better model.

        'r2' is a goodness-of-fit score (larger is better). 'rmse' and 'mse' are
        error measures (smaller is better), so they return False.
        """
        return self.getMetricName() == "r2"