I am working on a transfer learning example. My task is to take VGG-16, modify it, and train it using transfer learning only, so that it predicts 2 labels ("Cat" and "Dog"). Here is my code (please excuse any messiness):
    static DataSetIterator trainIter;

    public static void main(String[] args) throws Exception {
        SpringApplication.run(AiProjectDl4jApplication.class, args);
        int seed = 12345;
        int numClasses = 2;

        ZooModel zooModel = VGG16.builder().build();
        ComputationGraph pretrainedNet = (ComputationGraph) zooModel.initPretrained(PretrainedType.IMAGENET);

        FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder()
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .updater(new Nesterovs(5e-5))
                .seed(seed)
                .build();

        ComputationGraph vgg16Transfer = new TransferLearning.GraphBuilder(pretrainedNet)
                .fineTuneConfiguration(fineTuneConf)
                .setFeatureExtractor("fc2")
                .removeVertexKeepConnections("predictions")
                .addLayer("predictions",
                        new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                                .nIn(4096).nOut(numClasses)
                                .weightInit(WeightInit.XAVIER)
                                .activation(Activation.SOFTMAX).build(), "fc2")
                .build();

        TransferLearningHelper transferLearningHelper =
                new TransferLearningHelper(vgg16Transfer, "fc2");

        getImages();
        DataSetPreProcessor preProcessor = new VGG16ImagePreProcessor();
        trainIter.setPreProcessor(preProcessor);

        vgg16Transfer.setListeners(new ScoreIterationListener(5));
        log.info("Training starting...");
        for (int i = 0; i < 5; i++) {
            while (trainIter.hasNext()) {
                DataSet currentFeaturized = trainIter.next();
                vgg16Transfer.fit(currentFeaturized);
            }
        }
        log.info("If no results shown -> Summary: Failed");
    }

    static void getImages() throws IOException {
        Random rng = new Random();
        File parentDir = new File("Path To DataSet");
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
        ImageRecordReader recordReader = new ImageRecordReader(224, 224, 3, labelMaker);
        recordReader.initialize(new FileSplit(new File(String.valueOf(parentDir))));
        trainIter = new RecordReaderDataSetIterator(recordReader, 150, 1, 2);
    }
}
The problem occurs when I try to train the modified model. I get a result only for the first iteration (which is already pretty high, at 0.9171) and then get this error:
# A fatal error has been detected by the Java Runtime Environment:
#
# EXCEPTION_ACCESS_VIOLATION (0xc0000005) at pc=0x00007ffb426a3b29, pid=11708, tid=0x0000000000001a78
#
# JRE version: Java(TM) SE Runtime Environment (8.0_251-b08) (build 1.8.0_251-b08)
# Java VM: Java HotSpot(TM) 64-Bit Server VM (25.251-b08 mixed mode windows-amd64 compressed oops)
# Problematic frame:
# C [KERNELBASE.dll+0x43b29]
#
# Failed to write core dump. Minidumps are not enabled by default on client versions of Windows
I already tried to build and train my own network a couple of months ago and ran into the same problem, which nobody could really help me with. This time I hope you can. Do you have any idea what could cause it?