## Licensed to the Apache Software Foundation (ASF) under one or more# contributor license agreements. See the NOTICE file distributed with# this work for additional information regarding copyright ownership.# The ASF licenses this file to You under the Apache License, Version 2.0# (the "License"); you may not use this file except in compliance with# the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.#fromtypingimportClassVar,Type,Dict,List,Optional,Union,castfrompyspark.java_gatewayimportlocal_connect_and_authfrompyspark.resourceimportResourceInformationfrompyspark.serializersimportread_int,write_int,write_with_length,UTF8Deserializer
class TaskContext:
    """
    Contextual information about a task which can be read or mutated during
    execution. To access the TaskContext for a running task, use:
    :meth:`TaskContext.get`.
    """

    # The single shared instance handed out by ``get`` / ``_getOrCreate`` / ``__new__``.
    _taskContext: ClassVar[Optional["TaskContext"]] = None

    # Per-task fields. They stay None until assigned by whatever code populates
    # the context (not visible in this class).
    _attemptNumber: Optional[int] = None
    _partitionId: Optional[int] = None
    _stageId: Optional[int] = None
    _taskAttemptId: Optional[int] = None
    _localProperties: Optional[Dict[str, str]] = None
    _cpus: Optional[int] = None
    # Quoted forward reference: ``ResourceInformation`` is only needed for typing.
    _resources: Optional[Dict[str, "ResourceInformation"]] = None

    def __new__(cls: Type["TaskContext"]) -> "TaskContext":
        """Even if users construct TaskContext instead of using get, give them the singleton."""
        taskContext = cls._taskContext
        if taskContext is not None:
            return taskContext
        cls._taskContext = taskContext = object.__new__(cls)
        return taskContext

    @classmethod
    def _getOrCreate(cls: Type["TaskContext"]) -> "TaskContext":
        """Internal function to get or create global TaskContext."""
        if cls._taskContext is None:
            cls._taskContext = TaskContext()
        return cls._taskContext

    @classmethod
    def _setTaskContext(cls: Type["TaskContext"], taskContext: "TaskContext") -> None:
        """Internal function that replaces the global TaskContext with ``taskContext``."""
        cls._taskContext = taskContext

    @classmethod
    def get(cls: Type["TaskContext"]) -> Optional["TaskContext"]:
        """
        Return the currently active TaskContext. This can be called inside of
        user functions to access contextual information about running tasks.

        Notes
        -----
        Must be called on the worker, not the driver. Returns None if not initialized.
        """
        return cls._taskContext

    def stageId(self) -> int:
        """The ID of the stage that this task belongs to."""
        return cast(int, self._stageId)

    def partitionId(self) -> int:
        """
        The ID of the RDD partition that is computed by this task.
        """
        return cast(int, self._partitionId)

    def attemptNumber(self) -> int:
        """
        How many times this task has been attempted. The first task attempt will be
        assigned attemptNumber = 0, and subsequent attempts will have increasing attempt numbers.
        """
        return cast(int, self._attemptNumber)

    def taskAttemptId(self) -> int:
        """
        An ID that is unique to this task attempt (within the same SparkContext,
        no two task attempts will share the same attempt ID). This is roughly equivalent
        to Hadoop's TaskAttemptID.
        """
        return cast(int, self._taskAttemptId)

    def getLocalProperty(self, key: str) -> Optional[str]:
        """
        Get a local property set upstream in the driver, or None if it is missing.

        Parameters
        ----------
        key : str
            Name of the local property.

        Returns
        -------
        str or None
            The property value, or None when the key is absent or when no local
            properties have been set on this context at all.
        """
        properties = self._localProperties
        if properties is None:
            # Nothing populated yet: report "missing" instead of failing with
            # AttributeError on None.
            return None
        return properties.get(key)

    def cpus(self) -> int:
        """
        CPUs allocated to the task.
        """
        return cast(int, self._cpus)

    def resources(self) -> Dict[str, "ResourceInformation"]:
        """
        Resources allocated to the task. The key is the resource name and the value is information
        about the resource.
        """
        return cast(Dict[str, "ResourceInformation"], self._resources)
BARRIER_FUNCTION=1ALL_GATHER_FUNCTION=2def_load_from_socket(port:Optional[Union[str,int]],auth_secret:str,function:int,all_gather_message:Optional[str]=None,)->List[str]:""" Load data from a given socket, this is a blocking method thus only return when the socket connection has been closed. """(sockfile,sock)=local_connect_and_auth(port,auth_secret)# The call may block forever, so no timeoutsock.settimeout(None)iffunction==BARRIER_FUNCTION:# Make a barrier() function call.write_int(function,sockfile)eliffunction==ALL_GATHER_FUNCTION:# Make a all_gather() function call.write_int(function,sockfile)write_with_length(cast(str,all_gather_message).encode("utf-8"),sockfile)else:raiseValueError("Unrecognized function type")sockfile.flush()# Collect result.len=read_int(sockfile)res=[]foriinrange(len):res.append(UTF8Deserializer().loads(sockfile))# Release resources.sockfile.close()sock.close()returnres
class BarrierTaskContext(TaskContext):
    """
    A :class:`TaskContext` with extra contextual info and tooling for tasks in a barrier stage.
    Use :func:`BarrierTaskContext.get` to obtain the barrier context for a running barrier task.

    .. versionadded:: 2.4.0

    Notes
    -----
    This API is experimental
    """

    # Connection details set by ``_initialize``; the barrier operations refuse to run while
    # either is still None.
    _port: ClassVar[Optional[Union[str, int]]] = None
    _secret: ClassVar[Optional[str]] = None

    @classmethod
    def _getOrCreate(cls: Type["BarrierTaskContext"]) -> "BarrierTaskContext":
        """
        Internal function to get or create global BarrierTaskContext. We need to make sure
        BarrierTaskContext is returned from here because it is needed in python worker reuse
        scenario, see SPARK-25921 for more details.
        """
        if not isinstance(cls._taskContext, BarrierTaskContext):
            cls._taskContext = object.__new__(cls)
        return cls._taskContext

    @classmethod
    def get(cls: Type["BarrierTaskContext"]) -> "BarrierTaskContext":
        """
        Return the currently active :class:`BarrierTaskContext`.
        This can be called inside of user functions to access contextual information about
        running tasks.

        Raises
        ------
        RuntimeError
            If the current task context is not a :class:`BarrierTaskContext`, i.e. the task
            is not in a barrier stage or no context has been initialized.

        Notes
        -----
        Must be called on the worker, not the driver.

        This API is experimental
        """
        if not isinstance(cls._taskContext, BarrierTaskContext):
            raise RuntimeError("It is not in a barrier stage")
        return cls._taskContext

    @classmethod
    def _initialize(
        cls: Type["BarrierTaskContext"], port: Optional[Union[str, int]], secret: str
    ) -> None:
        """
        Initialize BarrierTaskContext, other methods within BarrierTaskContext can only be called
        after BarrierTaskContext is initialized.
        """
        cls._port = port
        cls._secret = secret

    def barrier(self) -> None:
        """
        Sets a global barrier and waits until all tasks in this stage hit this barrier.
        Similar to `MPI_Barrier` function in MPI, this function blocks until all tasks
        in the same stage have reached this routine.

        .. versionadded:: 2.4.0

        .. warning:: In a barrier stage, each task must have the same number of `barrier()`
            calls, in all possible code branches. Otherwise, you may get the job hanging
            or a SparkException after timeout.

        Raises
        ------
        RuntimeError
            If called before BarrierTaskContext is initialized.

        Notes
        -----
        This API is experimental
        """
        if self._port is None or self._secret is None:
            raise RuntimeError("Not supported to call barrier() before initialize BarrierTaskContext.")
        _load_from_socket(self._port, self._secret, BARRIER_FUNCTION)

    def allGather(self, message: str = "") -> List[str]:
        """
        This function blocks until all tasks in the same stage have reached this routine.
        Each task passes in a message and returns with a list of all the messages passed in
        by each of those tasks.

        .. versionadded:: 3.0.0

        .. warning:: In a barrier stage, each task must have the same number of `allGather()`
            calls, in all possible code branches. Otherwise, you may get the job hanging
            or a SparkException after timeout.

        Parameters
        ----------
        message : str
            The message this task contributes.

        Returns
        -------
        list of str
            The messages passed in by the tasks of this stage.

        Raises
        ------
        TypeError
            If ``message`` is not a ``str``.
        RuntimeError
            If called before BarrierTaskContext is initialized.

        Notes
        -----
        This API is experimental
        """
        if not isinstance(message, str):
            raise TypeError("Argument `message` must be of type `str`")
        if self._port is None or self._secret is None:
            # The message names allGather(); it previously said barrier() by copy-paste.
            raise RuntimeError(
                "Not supported to call allGather() before initialize BarrierTaskContext."
            )
        return _load_from_socket(self._port, self._secret, ALL_GATHER_FUNCTION, message)

    def getTaskInfos(self) -> List["BarrierTaskInfo"]:
        """
        Returns :class:`BarrierTaskInfo` for all tasks in this barrier stage,
        ordered by partition ID.

        .. versionadded:: 2.4.0

        Raises
        ------
        RuntimeError
            If called before BarrierTaskContext is initialized.

        Notes
        -----
        This API is experimental
        """
        if self._port is None or self._secret is None:
            raise RuntimeError(
                "Not supported to call getTaskInfos() before initialize BarrierTaskContext."
            )
        # The "addresses" local property is a comma-separated list of task addresses.
        addresses = cast(Dict[str, str], self._localProperties).get("addresses", "")
        return [BarrierTaskInfo(h.strip()) for h in addresses.split(",")]
class BarrierTaskInfo:
    """
    Carries all task infos of a barrier task.

    .. versionadded:: 2.4.0

    Attributes
    ----------
    address : str
        The IPv4 address (host:port) of the executor that the barrier task is running on

    Notes
    -----
    This API is experimental
    """

    def __init__(self, address: str) -> None:
        self.address = address

    def __repr__(self) -> str:
        # Readable form for logs/debugging instead of the default object address.
        return f"{type(self).__name__}(address={self.address!r})"