Skip to content

Commit e06d16a

Browse files
committed
update
1 parent 9003396 commit e06d16a

File tree

1 file changed

+27
-20
lines changed

1 file changed

+27
-20
lines changed

clm-apr/codegen/defects4j_codegen.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,8 @@ def defects4j_codegen_output(input_file, output_file, num_output=10):
139139

140140
# 生成多个输出
141141
outputs = []
142-
for _ in range(num_output):
142+
for i in range(num_output):
143+
print(f"Generating output {i+1}/{num_output}")
143144
# 构建 prompt
144145
prompt = f"[INST] This is an incorrect code({filename}):\n```java\n{code}\n```\nYou are a software engineer. Can you repair the incorrect code?\n[/INST]\n```java\n"
145146
print(f"Prompt: {prompt}", flush=True)
@@ -149,30 +150,36 @@ def defects4j_codegen_output(input_file, output_file, num_output=10):
149150
max_d = 500 - cnt
150151

151152
while True:
152-
response = pipe(
153-
prompt,
154-
min_length=cnt+64,
155-
max_length=cnt+max_d,
156-
temperature=1.0,
157-
do_sample=True
158-
)[0]['generated_text']
159-
160-
# 提取生成的代码
161153
try:
162-
print('response:', response)
163-
generated_code = response.split('[/INST]')[1].strip()
164-
if generated_code.startswith('```java\n'):
165-
generated_code = generated_code[8:]
166-
if generated_code.endswith('```'):
167-
generated_code = generated_code[:-3]
168-
generated_code = generated_code.strip()
154+
response = pipe(
155+
prompt,
156+
min_length=cnt+64,
157+
max_length=cnt+max_d,
158+
temperature=1.0,
159+
do_sample=True
160+
)[0]['generated_text']
161+
162+
print('Full response:', response)
163+
164+
# 提取生成的代码
165+
if '[/INST]' in response:
166+
generated_code = response.split('[/INST]')[1].strip()
167+
if generated_code.startswith('```java\n'):
168+
generated_code = generated_code[8:]
169+
if generated_code.endswith('```'):
170+
generated_code = generated_code[:-3]
171+
generated_code = generated_code.strip()
172+
173+
if generated_code and not generated_code.isspace():
174+
print(f"Generated code: {generated_code}")
175+
outputs.append(generated_code)
176+
break
169177

170-
if generated_code != '':
171-
break
172178
except Exception as e:
173-
print(f"Error processing {filename}: {e}")
179+
print(f"Error in generation: {e}")
174180
pass
175181

182+
print(f"Retrying with increased length...")
176183
max_d = min(2000 - cnt, max_d + 500)
177184

178185
outputs.append(generated_code)

0 commit comments

Comments
 (0)