Skip to content

Commit

Permalink
ShardedIntegrityChecksSparkJob Deduplication Revamp (osmlab#498)
Browse files Browse the repository at this point in the history
* new dedupe

* fix event service

* event service per partition

* debug count

* debug count2

* debug count3

* part id

* spotless

* raw flags

* concurrent queue

* comments

* geojson name
  • Loading branch information
Bentleysb authored Feb 10, 2021
1 parent 5625945 commit 69f79b3
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 157 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,20 @@
import static org.openstreetmap.atlas.checks.distributed.IntegrityCheckSparkJob.METRICS_FILENAME;

import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

import org.apache.spark.TaskContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.broadcast.Broadcast;
Expand Down Expand Up @@ -60,7 +65,6 @@
import com.google.common.eventbus.Subscribe;

import scala.Serializable;
import scala.Tuple2;

/**
* A spark job for generating integrity checks in a sharded fashion. This allows for a lower local
Expand Down Expand Up @@ -106,6 +110,7 @@ public String getName()
}

@Override
@SuppressWarnings("unchecked")
public void start(final CommandMap commandMap)
{
final Time start = Time.now();
Expand All @@ -116,12 +121,10 @@ public void start(final CommandMap commandMap)

// Gather arguments
final String output = this.output(commandMap);
@SuppressWarnings("unchecked")
final Set<OutputFormats> outputFormats = (Set<OutputFormats>) commandMap
.get(OUTPUT_FORMATS);
final StringList countries = StringList.split((String) commandMap.get(COUNTRIES),
CommonConstants.COMMA);
@SuppressWarnings("unchecked")
final Optional<List<String>> checkFilter = (Optional<List<String>>) commandMap
.getOption(CHECK_FILTER);

Expand All @@ -148,7 +151,6 @@ public void start(final CommandMap commandMap)
fileFetcher);

// Get sharding
@SuppressWarnings("unchecked")
final Optional<String> alternateShardingFile = (Optional<String>) commandMap
.getOption(SHARDING);
final String shardingPathInAtlas = "dynamic@"
Expand All @@ -160,7 +162,6 @@ public void start(final CommandMap commandMap)
final Distance distanceToLoadShards = (Distance) commandMap.get(EXPANSION_DISTANCE);

// get timeout
@SuppressWarnings("unchecked")
final Optional<Long> alternateMaxPoolMinutes = (Optional<Long>) commandMap
.getOption(MAX_POOL_MINUTES);
final Duration maxPoolDuration = Duration
Expand Down Expand Up @@ -214,24 +215,23 @@ public void start(final CommandMap commandMap)
{
checkPool.queue(() ->
{
final String country = countryShard.getKey();
// Generate a task for each shard
final List<ShardedCheckFlagsTask> tasksForCountry = countryShard.getValue()
.stream()
.map(shard -> new ShardedCheckFlagsTask(countryShard.getKey(), shard,
this.countryChecks.get(countryShard.getKey())))
final List<ShardedCheckFlagsTask> tasksForCountry = countryShard
.getValue().stream().map(shard -> new ShardedCheckFlagsTask(country,
shard, this.countryChecks.get(country)))
.collect(Collectors.toList());

// Set spark UI job title
this.getContext().setLocalProperty("callSite.short", String
.format("Running checks on %s", tasksForCountry.get(0).getCountry()));

this.getContext().parallelize(tasksForCountry, tasksForCountry.size())
.mapToPair(this.produceFlags(input, output, this.configurationMap(),
.flatMap(this.produceFlags(input, output, this.configurationMap(),
fileHelper, shardingBroadcast, distanceToLoadShards,
(Boolean) commandMap.get(MULTI_ATLAS)))
.reduceByKey(UniqueCheckFlagContainer::combine)
// Generate outputs
.foreach(this.processFlags(output, fileHelper, outputFormats));
.distinct().map(UniqueCheckFlagContainer::getEvent).foreachPartition(
this.processFlags(output, fileHelper, outputFormats, country));
});
}
}
Expand Down Expand Up @@ -277,18 +277,19 @@ private Function<Shard, Optional<Atlas>> atlasFetcher(final String input, final
* @param outputFormats
* {@link Set} of
* {@link org.openstreetmap.atlas.checks.distributed.IntegrityChecksCommandArguments.OutputFormats}
* @return {@link VoidFunction} that takes a {@link Tuple2} of a {@link String} country code and
* a {@link UniqueCheckFlagContainer}
* @param country
* {@link String} ISO code for the country being processed
* @return {@link VoidFunction} that takes an {@link Iterator} of {@link CheckFlagEvent}s
*/
@SuppressWarnings("unchecked")
private VoidFunction<Tuple2<String, UniqueCheckFlagContainer>> processFlags(final String output,
final SparkFileHelper fileHelper, final Set<OutputFormats> outputFormats)
private VoidFunction<Iterator<CheckFlagEvent>> processFlags(final String output,
final SparkFileHelper fileHelper, final Set<OutputFormats> outputFormats,
final String country)
{
return tuple ->
return iterator ->
{
final String country = tuple._1();
final UniqueCheckFlagContainer flagContainer = tuple._2();
final EventService<CheckFlagEvent> eventService = EventService.get(country);
final EventService<CheckFlagEvent> eventService = EventService
.get(country + TaskContext.getPartitionId());

if (outputFormats.contains(OutputFormats.FLAGS))
{
Expand All @@ -309,7 +310,7 @@ private VoidFunction<Tuple2<String, UniqueCheckFlagContainer>> processFlags(fina
SparkFileHelper.combine(output, OUTPUT_TIPPECANOE_FOLDER, country)));
}

flagContainer.reconstructEvents().parallel().forEach(eventService::post);
iterator.forEachRemaining(eventService::post);
eventService.complete();
};
}
Expand All @@ -332,11 +333,11 @@ private VoidFunction<Tuple2<String, UniqueCheckFlagContainer>> processFlags(fina
* {@link Distance} to expand the shard group
* @param multiAtlas
* boolean whether to use a multi or dynamic Atlas
* @return {@link PairFunction} that takes {@link ShardedCheckFlagsTask} and returns a
* {@link Tuple2} of a {@link String} country code and {@link UniqueCheckFlagContainer}
* @return {@link FlatMapFunction} that takes {@link ShardedCheckFlagsTask} and returns an
* {@link Iterator} of {@link UniqueCheckFlagContainer}s
*/
@SuppressWarnings("unchecked")
private PairFunction<ShardedCheckFlagsTask, String, UniqueCheckFlagContainer> produceFlags(
private FlatMapFunction<ShardedCheckFlagsTask, UniqueCheckFlagContainer> produceFlags(
final String input, final String output, final Map<String, String> configurationMap,
final SparkFileHelper fileHelper, final Broadcast<Sharding> sharding,
final Distance shardDistanceExpansion, final boolean multiAtlas)
Expand Down Expand Up @@ -376,7 +377,7 @@ private PairFunction<ShardedCheckFlagsTask, String, UniqueCheckFlagContainer> pr

// Prepare the event service
final EventService eventService = task.getEventService();
final UniqueCheckFlagContainer container = new UniqueCheckFlagContainer();
final Queue<UniqueCheckFlagContainer> container = new ConcurrentLinkedQueue<>();
eventService.register(new Processor<CheckFlagEvent>()
{
@Override
Expand All @@ -390,7 +391,7 @@ public void process(final ShutdownEvent event)
@AllowConcurrentEvents
public void process(final CheckFlagEvent event)
{
container.add(event.getCheckName(), event.getCheckFlag().makeComplete());
container.add(new UniqueCheckFlagContainer(event));
}
});
// Metrics are output on a per shard level
Expand All @@ -411,7 +412,7 @@ public void process(final CheckFlagEvent event)
}

eventService.complete();
return new Tuple2<>(task.getCountry(), container);
return container.iterator();
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;

import org.apache.spark.TaskContext;
import org.openstreetmap.atlas.checks.constants.CommonConstants;
import org.openstreetmap.atlas.checks.distributed.GeoJsonPathFilter;
import org.openstreetmap.atlas.event.Processor;
Expand Down Expand Up @@ -178,7 +179,8 @@ protected int computeBatchSize()

protected String getFilename(final String challenge, final int size)
{
return String.format("%s-%s-%s%s", challenge, new Date().getTime(), size,
return String.format("%s-%sP%s-%s%s", challenge, new Date().getTime(),
TaskContext.getPartitionId(), size,
new GeoJsonPathFilter(this.compressOutput).getExtension());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import java.util.Date;

import org.apache.spark.TaskContext;
import org.openstreetmap.atlas.checks.distributed.GeoJsonPathFilter;
import org.openstreetmap.atlas.checks.vectortiles.TippecanoeCheckSettings;
import org.openstreetmap.atlas.generator.tools.spark.utilities.SparkFileHelper;
Expand Down Expand Up @@ -64,7 +65,7 @@ public void process(final org.openstreetmap.atlas.event.ShutdownEvent event)
@Override
protected String getFilename()
{
return String.format("%s-%s%s", new Date().getTime(), getCount(),
new GeoJsonPathFilter(doesCompressOutput()).getExtension());
return String.format("%sP%s-%s%s", new Date().getTime(), TaskContext.getPartitionId(),
getCount(), new GeoJsonPathFilter(doesCompressOutput()).getExtension());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;

import org.apache.spark.TaskContext;
import org.openstreetmap.atlas.checks.distributed.LogFilePathFilter;
import org.openstreetmap.atlas.event.Event;
import org.openstreetmap.atlas.event.Processor;
Expand Down Expand Up @@ -159,8 +160,8 @@ public FileProcessor<T> withCompression(final boolean compress)
*/
protected String getFilename()
{
return String.format("%s-%s%s", new Date().getTime(), this.getCount(),
new LogFilePathFilter(this.compressOutput).getExtension());
return String.format("%sP%s-%s%s", new Date().getTime(), TaskContext.getPartitionId(),
this.getCount(), new LogFilePathFilter(this.compressOutput).getExtension());
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,109 +1,90 @@
package org.openstreetmap.atlas.checks.utility;

import java.io.Serializable;
import java.util.Collection;
import java.util.Collections;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Stream;

import org.openstreetmap.atlas.checks.event.CheckFlagEvent;
import org.openstreetmap.atlas.checks.flag.CheckFlag;

/**
* A container that will deduplicate check flags based on source and unique IDs
* A container used to deduplicate check flags based on checkName and unique IDs
*
* @author jklamer
* @author bbreithaupt
*/
public class UniqueCheckFlagContainer implements Serializable
{

private final ConcurrentHashMap<String, ConcurrentHashMap<Set<String>, CheckFlag>> uniqueFlags;
private String checkName;
private Set<String> uniqueIdentifiers;
private CheckFlag checkFlag;

/**
* Combines two containers. This deduplicates {@link CheckFlag}s by overwriting ones with matching
* sources and IDs.
*
* @param container1
* {@link UniqueCheckFlagContainer}
* @param container2
* {@link UniqueCheckFlagContainer}
* @return merged {@link UniqueCheckFlagContainer}
* @param checkFlagEvent
* {@link CheckFlagEvent}
*/
public static UniqueCheckFlagContainer combine(final UniqueCheckFlagContainer container1,
final UniqueCheckFlagContainer container2)
public UniqueCheckFlagContainer(final CheckFlagEvent checkFlagEvent)
{
container2.uniqueFlags.entrySet()
.forEach(entry -> container1.addAll(entry.getKey(), entry.getValue().values()));
return container1;
this(checkFlagEvent.getCheckName(), checkFlagEvent.getCheckFlag().getUniqueIdentifiers(),
checkFlagEvent.getCheckFlag().makeComplete());
}

public UniqueCheckFlagContainer()
/**
* @param checkName
* {@link String} check name
* @param uniqueIdentifiers
* {@link Set} of {@link String}s from {@link CheckFlag#getUniqueIdentifiers()}
* @param checkFlag
* {@link CheckFlag}
*/
public UniqueCheckFlagContainer(final String checkName, final Set<String> uniqueIdentifiers,
final CheckFlag checkFlag)
{
this.uniqueFlags = new ConcurrentHashMap<>();
this.checkName = checkName;
this.uniqueIdentifiers = uniqueIdentifiers;
this.checkFlag = checkFlag.makeComplete();
}

@SuppressWarnings("s1144")
// Ignore unused constructor warning, this is used for deserialization
private UniqueCheckFlagContainer(
final ConcurrentHashMap<String, ConcurrentHashMap<Set<String>, CheckFlag>> flags)
@Override
public boolean equals(final Object other)
{
this.uniqueFlags = flags;
if (this == other)
{
return true;
}
if (other == null || getClass() != other.getClass())
{
return false;
}
final UniqueCheckFlagContainer that = (UniqueCheckFlagContainer) other;
return Objects.equals(this.checkName, that.checkName)
&& Objects.equals(this.uniqueIdentifiers, that.uniqueIdentifiers);
}

/**
* Add a {@link CheckFlag} to the container based on its source.
*
* @param flagSource
* {@link String} source (check that generated the flag)
* @param flag
* {@link CheckFlag}
*/
public void add(final String flagSource, final CheckFlag flag)
public CheckFlag getCheckFlag()
{
this.uniqueFlags.putIfAbsent(flagSource, new ConcurrentHashMap<>());
final Set<String> uniqueObjectIdentifiers = flag.getUniqueIdentifiers();
this.uniqueFlags.get(flagSource)
.putIfAbsent(uniqueObjectIdentifiers.isEmpty()
? Collections.singleton(flag.getIdentifier())
: uniqueObjectIdentifiers, flag);
return this.checkFlag;
}

/**
* Batch add {@link CheckFlag} from a single source.
*
* @param flagSource
* {@link String} source (check that generated the flags)
* @param flags
* {@link Iterable} of {@link CheckFlag}s
*/
public void addAll(final String flagSource, final Iterable<CheckFlag> flags)
public String getCheckName()
{
flags.forEach(flag -> this.add(flagSource, flag));
return this.checkName;
}

public CheckFlagEvent getEvent()
{
return new CheckFlagEvent(this.checkName, this.checkFlag);
}

/**
* Convert the {@link CheckFlag}s into a {@link Stream} of {@link CheckFlagEvent}s.
*
* @return a {@link Stream} of {@link CheckFlagEvent}s
*/
public Stream<CheckFlagEvent> reconstructEvents()
public Set<String> getUniqueIdentifiers()
{
return this.uniqueFlags.keySet().stream()
.flatMap(checkName -> this.uniqueFlags.get(checkName).values().stream()
.map(checkFlag -> new CheckFlagEvent(checkName, checkFlag)));
return this.uniqueIdentifiers;
}

/**
* Get the contents of the container as a stream.
*
* @return a {@link Stream} of {@link CheckFlag}s
*/
public Stream<CheckFlag> stream()
@Override
public int hashCode()
{
return this.uniqueFlags.values().stream().map(ConcurrentHashMap::values)
.flatMap(Collection::stream);
return Objects.hash(this.checkName, this.uniqueIdentifiers);
}
}
Loading

0 comments on commit 69f79b3

Please sign in to comment.