diff --git a/src/java/com/cloudera/sqoop/orm/ClassWriter.java b/src/java/com/cloudera/sqoop/orm/ClassWriter.java index fc14a55a..b5b8a12d 100644 --- a/src/java/com/cloudera/sqoop/orm/ClassWriter.java +++ b/src/java/com/cloudera/sqoop/orm/ClassWriter.java @@ -426,13 +426,14 @@ private String rpcSetterForMaybeNull(String javaType, String outputObj, } /** - * Generate a member field and getter method for each column. + * Generate a member field, getter, setter and with method for each column. * @param columnTypes - mapping from column names to sql types - * @param colNames - ordered list of column names for table. + * @param colNames - ordered list of column names for table + * @param className - name of the generated class * @param sb - StringBuilder to append code to */ private void generateFields(Map columnTypes, - String [] colNames, StringBuilder sb) { + String [] colNames, String className, StringBuilder sb) { for (String col : colNames) { int sqlType = columnTypes.get(col); @@ -446,9 +447,51 @@ private void generateFields(Map columnTypes, sb.append(" public " + javaType + " get_" + col + "() {\n"); sb.append(" return " + col + ";\n"); sb.append(" }\n"); + sb.append(" public void set_" + col + "(" + javaType + " " + col + + ") {\n"); + sb.append(" this." + col + " = " + col + ";\n"); + sb.append(" }\n"); + sb.append(" public " + className + " with_" + col + "(" + javaType + " " + + col + ") {\n"); + sb.append(" this." + col + " = " + col + ";\n"); + sb.append(" return this;\n"); + sb.append(" }\n"); } } + /** + * Generate an equals method that compares the fields for each column. + * @param columnTypes - mapping from column names to sql types + * @param colNames - ordered list of column names for table + * @param className - name of the generated class + * @param sb - StringBuilder to append code to + */ + private void generateEquals(Map columnTypes, + String [] colNames, String className, StringBuilder sb) { + + sb.append(" public boolean equals(Object o) {\n"); + sb.append(" if (this == o) {\n"); + sb.append(" return true;\n"); + sb.append(" }\n"); + sb.append(" if (!(o instanceof " + className + ")) {\n"); + sb.append(" return false;\n"); + sb.append(" }\n"); + sb.append(" " + className + " that = (" + className + ") o;\n"); + sb.append(" boolean equal = true;\n"); + for (String col : colNames) { + int sqlType = columnTypes.get(col); + String javaType = connManager.toJavaType(sqlType); + if (null == javaType) { + LOG.error("Cannot resolve SQL type " + sqlType); + continue; + } + sb.append(" equal = equal && (this." + col + " == null ? that." + col + + " == null : this." + col + ".equals(that." + col + "));\n"); + } + sb.append(" return equal;\n"); + sb.append(" }\n"); + } + /** * Generate the readFields() method used by the database. * @param columnTypes - mapping from column names to sql types @@ -1180,7 +1223,8 @@ private StringBuilder generateClassForColumns( sb.append( " public int getClassFormatVersion() { return PROTOCOL_VERSION; }\n"); sb.append(" protected ResultSet __cur_result_set;\n"); - generateFields(columnTypes, colNames, sb); + generateFields(columnTypes, colNames, className, sb); + generateEquals(columnTypes, colNames, className, sb); generateDbRead(columnTypes, colNames, sb); generateLoadLargeObjects(columnTypes, colNames, sb); generateDbWrite(columnTypes, dbWriteColNames, sb); diff --git a/src/test/com/cloudera/sqoop/orm/TestClassWriter.java b/src/test/com/cloudera/sqoop/orm/TestClassWriter.java index 63d4ad94..23dbd10c 100644 --- a/src/test/com/cloudera/sqoop/orm/TestClassWriter.java +++ b/src/test/com/cloudera/sqoop/orm/TestClassWriter.java @@ -21,6 +21,8 @@ import java.io.File; import java.io.FileInputStream; import java.io.IOException; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; import java.sql.Connection; import java.sql.Statement; import java.sql.SQLException; @@ -41,6 +43,7 @@ import com.cloudera.sqoop.testutil.HsqldbTestServer; import com.cloudera.sqoop.testutil.ImportJobTestCase; import com.cloudera.sqoop.tool.ImportTool; +import com.cloudera.sqoop.util.ClassLoaderStack; /** * Test that the ClassWriter generates Java classes based on the given table, @@ -118,8 +121,9 @@ public void tearDown() { /** * Run a test to verify that we can generate code and it emits the output * files where we expect them. + * @return */ - private void runGenerationTest(String [] argv, String classNameToCheck) { + private File runGenerationTest(String [] argv, String classNameToCheck) { File codeGenDirFile = new File(CODE_GEN_DIR); File classGenDirFile = new File(JAR_GEN_DIR); @@ -164,7 +168,7 @@ private void runGenerationTest(String [] argv, String classNameToCheck) { assertTrue("Cannot find compiled jar", jarFile.exists()); LOG.debug("Found generated jar: " + jarFile); - // check that the .class file made it into the .jar by enumerating + // check that the .class file made it into the .jar by enumerating // available entries in the jar file. boolean foundCompiledClass = false; try { @@ -195,6 +199,7 @@ private void runGenerationTest(String [] argv, String classNameToCheck) { + ".class in jar file", foundCompiledClass); LOG.debug("Found class in jar - test success!"); + return jarFile; } /** @@ -258,7 +263,7 @@ public void testSetClassAndPackageName() { runGenerationTest(argv, OVERRIDE_CLASS_AND_PACKAGE_NAME); } - + private static final String OVERRIDE_PACKAGE_NAME = "special.userpackage.name"; @@ -334,5 +339,87 @@ public void testWeirdColumnNames() throws SQLException { runGenerationTest(argv, OVERRIDE_PACKAGE_NAME + "." + HsqldbTestServer.getTableName()); } -} + /** + * Test the generated equals method. + * @throws IOException + * @throws ClassNotFoundException + * @throws IllegalAccessException + * @throws InstantiationException + * @throws NoSuchMethodException + * @throws SecurityException + * @throws InvocationTargetException + * @throws IllegalArgumentException + */ + @Test + public void testEqualsMethod() throws IOException, ClassNotFoundException, + InstantiationException, IllegalAccessException, NoSuchMethodException, + InvocationTargetException { + + // Set the option strings in an "argv" to redirect our srcdir and bindir + String [] argv = { + "--bindir", + JAR_GEN_DIR, + "--outdir", + CODE_GEN_DIR, + "--class-name", + OVERRIDE_CLASS_AND_PACKAGE_NAME, + }; + + File ormJarFile = runGenerationTest(argv, OVERRIDE_CLASS_AND_PACKAGE_NAME); + ClassLoader prevClassLoader = ClassLoaderStack.addJarFile( + ormJarFile.getCanonicalPath(), + OVERRIDE_CLASS_AND_PACKAGE_NAME); + Class tableClass = Class.forName( + OVERRIDE_CLASS_AND_PACKAGE_NAME, + true, + Thread.currentThread().getContextClassLoader()); + Method setterIntField1 = + tableClass.getMethod("set_INTFIELD1", Integer.class); + Method setterIntField2 = + tableClass.getMethod("set_INTFIELD2", Integer.class); + Method equalsImplementation = tableClass.getMethod("equals", Object.class); + + Object instance1 = tableClass.newInstance(); + Object instance2 = tableClass.newInstance(); + + // test reflexivity + assertTrue((Boolean) equalsImplementation.invoke(instance1, instance1)); + + // test equality for uninitialized fields + assertTrue((Boolean) equalsImplementation.invoke(instance1, instance2)); + + // test symmetry + assertTrue((Boolean) equalsImplementation.invoke(instance2, instance1)); + + // test reflexivity with initialized fields + setterIntField1.invoke(instance1, new Integer(1)); + setterIntField2.invoke(instance1, new Integer(2)); + assertTrue((Boolean) equalsImplementation.invoke(instance1, instance1)); + + // test difference in both fields + setterIntField1.invoke(instance2, new Integer(3)); + setterIntField2.invoke(instance2, new Integer(4)); + assertFalse((Boolean) equalsImplementation.invoke(instance1, instance2)); + + // test difference in second field + setterIntField1.invoke(instance2, new Integer(1)); + setterIntField2.invoke(instance2, new Integer(3)); + assertFalse((Boolean) equalsImplementation.invoke(instance1, instance2)); + + // test difference in first field + setterIntField1.invoke(instance2, new Integer(3)); + setterIntField2.invoke(instance2, new Integer(2)); + assertFalse((Boolean) equalsImplementation.invoke(instance1, instance2)); + + // test equality for initialized fields + setterIntField1.invoke(instance2, new Integer(1)); + setterIntField2.invoke(instance2, new Integer(2)); + assertTrue((Boolean) equalsImplementation.invoke(instance1, instance2)); + + if (null != prevClassLoader) { + ClassLoaderStack.setCurrentClassLoader(prevClassLoader); + } + } + +}